ai_sdk_core/text/generate.rs

1use super::{GenerationConfig, OnPreliminaryToolResultCallback};
2use crate::error::GenerateError;
3use crate::impl_builder_core;
4use crate::text::{accumulate_usage, convert_content};
5use crate::tool::ToolExecutor;
6use crate::Result;
7use ai_sdk_provider::language_model::{
8    Content, FinishReason, LanguageModel, Message, Tool as ProviderTool, ToolCallPart, ToolChoice,
9    ToolResultPart, Usage,
10};
11use futures::future::BoxFuture;
12use std::future::IntoFuture;
13use std::sync::Arc;
14
15// -----------------------------------------------------------------------------
16// Builder Definition
17// -----------------------------------------------------------------------------
18
19/// Builder for text generation.
20///
21/// This builder allows configuring the model, prompt, tools, and other settings
22/// before executing the generation.
23pub struct GenerateTextBuilder<M, P> {
24    model: M,
25    prompt: P,
26    config: GenerationConfig,
27}
28
29// Use the macro to implement new(), setters, and transitions
30impl_builder_core!(GenerateTextBuilder);
31
32// Implement methods specific to this builder
33impl<M, P> GenerateTextBuilder<M, P> {
34    /// Set a callback for preliminary tool results.
35    ///
36    /// This is useful when you want to receive updates about tool execution
37    /// (e.g., partial outputs from long-running tools) while the tool is still running,
38    /// even in a non-streaming generation context.
39    pub fn on_preliminary_tool_result(mut self, callback: OnPreliminaryToolResultCallback) -> Self {
40        self.config.on_preliminary_tool_result = Some(callback);
41        self
42    }
43}
44
45// -----------------------------------------------------------------------------
46// Execution Logic
47// -----------------------------------------------------------------------------
48
49impl GenerateTextBuilder<Arc<dyn LanguageModel>, Vec<Message>> {
50    /// Execute the text generation.
51    ///
52    /// This will send the prompt to the model, handle tool calls if tools are provided,
53    /// and return the final result.
54    ///
55    /// # Returns
56    /// A `Result` containing `GenerateTextResult` on success, or a `GenerateError` on failure.
57    pub async fn execute(self) -> Result<GenerateTextResult> {
58        let model = self.model;
59        let mut messages = self.prompt;
60
61        // Unpack config
62        let GenerationConfig {
63            tools,
64            max_steps,
65            temperature,
66            max_tokens,
67            retry_policy,
68            on_preliminary_tool_result,
69        } = self.config;
70
71        let tool_executor = ToolExecutor::new(tools);
72        let mut steps = Vec::new();
73        let mut total_usage = Usage {
74            input_tokens: Some(0),
75            output_tokens: Some(0),
76            total_tokens: Some(0),
77            reasoning_tokens: None,
78            cached_input_tokens: None,
79        };
80
81        for step_index in 0..max_steps {
82            let mut builder = model.generate(messages.clone());
83
84            if let Some(temperature) = temperature {
85                builder = builder.temperature(temperature);
86            }
87            if let Some(max_tokens) = max_tokens {
88                builder = builder.max_tokens(max_tokens);
89            }
90
91            if !tool_executor.tools().is_empty() {
92                let tool_defs = tool_executor.tool_definitions();
93                builder = builder
94                    .tools(tool_defs.into_iter().map(ProviderTool::Function).collect())
95                    .tool_choice(ToolChoice::Auto);
96            }
97
98            // Retry policy
99            let response = retry_policy
100                .retry(|| {
101                    let builder = builder.clone();
102                    async move { builder.await }
103                })
104                .await
105                .map_err(GenerateError::ProviderError)?;
106
107            // Track usage
108            accumulate_usage(&mut total_usage, &response.usage);
109
110            let tool_calls = extract_tool_calls(&response.content);
111
112            // Execute tools if any
113            let mut tool_results = Vec::new();
114            if !tool_calls.is_empty() && response.finish_reason == FinishReason::ToolCalls {
115                if let Some(ref callback) = on_preliminary_tool_result {
116                    for tool_call in &tool_calls {
117                        let callback = callback.clone();
118                        let result = tool_executor
119                            .execute_tool_with_stream(tool_call.clone(), move |preliminary| {
120                                let cb = callback.clone();
121                                let preliminary = preliminary.clone();
122                                tokio::spawn(async move {
123                                    cb(preliminary).await;
124                                });
125                            })
126                            .await;
127                        tool_results.push(result);
128                    }
129                } else {
130                    tool_results = tool_executor.execute_tools(tool_calls.clone()).await;
131                };
132            }
133
134            steps.push(StepResult {
135                step_index,
136                response_content: response.content.clone(),
137                tool_calls: tool_calls.clone(),
138                tool_results: tool_results.clone(),
139                finish_reason: response.finish_reason,
140                usage: response.usage.clone(),
141            });
142
143            if tool_calls.is_empty() || response.finish_reason != FinishReason::ToolCalls {
144                break;
145            }
146
147            // Update history
148            messages.push(Message::Assistant {
149                content: response.content.into_iter().map(convert_content).collect(),
150            });
151            messages.push(Message::Tool {
152                content: tool_results,
153            });
154        }
155
156        Ok(GenerateTextResult { steps, total_usage })
157    }
158}
159
160// -----------------------------------------------------------------------------
161// IntoFuture
162// -----------------------------------------------------------------------------
163
164impl IntoFuture for GenerateTextBuilder<Arc<dyn LanguageModel>, Vec<Message>> {
165    type Output = Result<GenerateTextResult>;
166    type IntoFuture = BoxFuture<'static, Self::Output>;
167
168    fn into_future(self) -> Self::IntoFuture {
169        Box::pin(self.execute())
170    }
171}
172
173// -----------------------------------------------------------------------------
174// Helper Structs & Functions
175// -----------------------------------------------------------------------------
176
177/// Result of a text generation call.
178pub struct GenerateTextResult {
179    /// The sequence of steps executed during the generation.
180    /// Each step represents a single call to the language model.
181    steps: Vec<StepResult>,
182    /// The total token usage across all steps.
183    total_usage: Usage,
184}
185
186impl GenerateTextResult {
187    /// Returns the generated text from the last step.
188    ///
189    /// If the last step produced multiple text parts, they are concatenated.
190    pub fn text(&self) -> String {
191        self.steps
192            .last()
193            .map(|step| extract_text(&step.response_content))
194            .unwrap_or_default()
195    }
196
197    /// Returns the tool calls from the last step.
198    pub fn tool_calls(&self) -> &[ToolCallPart] {
199        self.steps
200            .last()
201            .map(|step| step.tool_calls.as_slice())
202            .unwrap_or_default()
203    }
204
205    /// Returns the tool results from the last step.
206    pub fn tool_results(&self) -> &[ToolResultPart] {
207        self.steps
208            .last()
209            .map(|step| step.tool_results.as_slice())
210            .unwrap_or_default()
211    }
212
213    /// Returns the list of steps executed.
214    pub fn steps(&self) -> &[StepResult] {
215        &self.steps
216    }
217
218    /// Returns the total usage across all steps.
219    pub fn usage(&self) -> &Usage {
220        &self.total_usage
221    }
222
223    /// Returns the finish reason of the last step.
224    pub fn finish_reason(&self) -> &FinishReason {
225        self.steps
226            .last()
227            .map(|s| &s.finish_reason)
228            .unwrap_or(&FinishReason::Stop)
229    }
230}
231
232/// Result of a single step in the generation process.
233#[derive(Debug, Clone)]
234pub struct StepResult {
235    /// The index of this step (0-based).
236    pub step_index: u32,
237    /// The content returned by the model in this step.
238    pub response_content: Vec<Content>,
239    /// The tool calls generated in this step.
240    pub tool_calls: Vec<ToolCallPart>,
241    /// The tool results produced in this step (executed after the model response).
242    pub tool_results: Vec<ToolResultPart>,
243    /// The reason why the model stopped generating in this step.
244    pub finish_reason: FinishReason,
245    /// The token usage for this step.
246    pub usage: Usage,
247}
248
249/// Generates text using a language model.
250///
251/// This function creates a builder that allows you to configure the model, prompt,
252/// tools, and other parameters.
253///
254/// # Example
255///
256/// ```rust,ignore
257/// use ai_sdk_core::generate_text;
258/// use ai_sdk_openai::openai;
259///
260/// let result = generate_text()
261///     .model(openai("gpt-4"))
262///     .prompt("Hello, world!")
263///     .execute()
264///     .await?;
265///
266/// println!("{}", result.text());
267/// ```
268pub fn generate_text() -> GenerateTextBuilder<(), ()> {
269    GenerateTextBuilder::new()
270}
271
272// Internal Helpers
273fn extract_tool_calls(content: &[Content]) -> Vec<ToolCallPart> {
274    content
275        .iter()
276        .filter_map(|c| match c {
277            Content::ToolCall(tc) => Some(tc.clone()),
278            _ => None,
279        })
280        .collect()
281}
282
283fn extract_text(content: &[Content]) -> String {
284    content
285        .iter()
286        .filter_map(|c| match c {
287            Content::Text(t) => Some(t.text.clone()),
288            _ => None,
289        })
290        .collect::<Vec<_>>()
291        .join("")
292}