ai_sdk_core/text/
generate.rs1use super::{GenerationConfig, OnPreliminaryToolResultCallback};
2use crate::error::GenerateError;
3use crate::impl_builder_core;
4use crate::text::{accumulate_usage, convert_content};
5use crate::tool::ToolExecutor;
6use crate::Result;
7use ai_sdk_provider::language_model::{
8 Content, FinishReason, LanguageModel, Message, Tool as ProviderTool, ToolCallPart, ToolChoice,
9 ToolResultPart, Usage,
10};
11use futures::future::BoxFuture;
12use std::future::IntoFuture;
13use std::sync::Arc;
14
15pub struct GenerateTextBuilder<M, P> {
24 model: M,
25 prompt: P,
26 config: GenerationConfig,
27}
28
29impl_builder_core!(GenerateTextBuilder);
31
32impl<M, P> GenerateTextBuilder<M, P> {
34 pub fn on_preliminary_tool_result(mut self, callback: OnPreliminaryToolResultCallback) -> Self {
40 self.config.on_preliminary_tool_result = Some(callback);
41 self
42 }
43}
44
45impl GenerateTextBuilder<Arc<dyn LanguageModel>, Vec<Message>> {
50 pub async fn execute(self) -> Result<GenerateTextResult> {
58 let model = self.model;
59 let mut messages = self.prompt;
60
61 let GenerationConfig {
63 tools,
64 max_steps,
65 temperature,
66 max_tokens,
67 retry_policy,
68 on_preliminary_tool_result,
69 } = self.config;
70
71 let tool_executor = ToolExecutor::new(tools);
72 let mut steps = Vec::new();
73 let mut total_usage = Usage {
74 input_tokens: Some(0),
75 output_tokens: Some(0),
76 total_tokens: Some(0),
77 reasoning_tokens: None,
78 cached_input_tokens: None,
79 };
80
81 for step_index in 0..max_steps {
82 let mut builder = model.generate(messages.clone());
83
84 if let Some(temperature) = temperature {
85 builder = builder.temperature(temperature);
86 }
87 if let Some(max_tokens) = max_tokens {
88 builder = builder.max_tokens(max_tokens);
89 }
90
91 if !tool_executor.tools().is_empty() {
92 let tool_defs = tool_executor.tool_definitions();
93 builder = builder
94 .tools(tool_defs.into_iter().map(ProviderTool::Function).collect())
95 .tool_choice(ToolChoice::Auto);
96 }
97
98 let response = retry_policy
100 .retry(|| {
101 let builder = builder.clone();
102 async move { builder.await }
103 })
104 .await
105 .map_err(GenerateError::ProviderError)?;
106
107 accumulate_usage(&mut total_usage, &response.usage);
109
110 let tool_calls = extract_tool_calls(&response.content);
111
112 let mut tool_results = Vec::new();
114 if !tool_calls.is_empty() && response.finish_reason == FinishReason::ToolCalls {
115 if let Some(ref callback) = on_preliminary_tool_result {
116 for tool_call in &tool_calls {
117 let callback = callback.clone();
118 let result = tool_executor
119 .execute_tool_with_stream(tool_call.clone(), move |preliminary| {
120 let cb = callback.clone();
121 let preliminary = preliminary.clone();
122 tokio::spawn(async move {
123 cb(preliminary).await;
124 });
125 })
126 .await;
127 tool_results.push(result);
128 }
129 } else {
130 tool_results = tool_executor.execute_tools(tool_calls.clone()).await;
131 };
132 }
133
134 steps.push(StepResult {
135 step_index,
136 response_content: response.content.clone(),
137 tool_calls: tool_calls.clone(),
138 tool_results: tool_results.clone(),
139 finish_reason: response.finish_reason,
140 usage: response.usage.clone(),
141 });
142
143 if tool_calls.is_empty() || response.finish_reason != FinishReason::ToolCalls {
144 break;
145 }
146
147 messages.push(Message::Assistant {
149 content: response.content.into_iter().map(convert_content).collect(),
150 });
151 messages.push(Message::Tool {
152 content: tool_results,
153 });
154 }
155
156 Ok(GenerateTextResult { steps, total_usage })
157 }
158}
159
160impl IntoFuture for GenerateTextBuilder<Arc<dyn LanguageModel>, Vec<Message>> {
165 type Output = Result<GenerateTextResult>;
166 type IntoFuture = BoxFuture<'static, Self::Output>;
167
168 fn into_future(self) -> Self::IntoFuture {
169 Box::pin(self.execute())
170 }
171}
172
173pub struct GenerateTextResult {
179 steps: Vec<StepResult>,
182 total_usage: Usage,
184}
185
186impl GenerateTextResult {
187 pub fn text(&self) -> String {
191 self.steps
192 .last()
193 .map(|step| extract_text(&step.response_content))
194 .unwrap_or_default()
195 }
196
197 pub fn tool_calls(&self) -> &[ToolCallPart] {
199 self.steps
200 .last()
201 .map(|step| step.tool_calls.as_slice())
202 .unwrap_or_default()
203 }
204
205 pub fn tool_results(&self) -> &[ToolResultPart] {
207 self.steps
208 .last()
209 .map(|step| step.tool_results.as_slice())
210 .unwrap_or_default()
211 }
212
213 pub fn steps(&self) -> &[StepResult] {
215 &self.steps
216 }
217
218 pub fn usage(&self) -> &Usage {
220 &self.total_usage
221 }
222
223 pub fn finish_reason(&self) -> &FinishReason {
225 self.steps
226 .last()
227 .map(|s| &s.finish_reason)
228 .unwrap_or(&FinishReason::Stop)
229 }
230}
231
/// Record of a single model call made during a [`generate_text`] run.
#[derive(Debug, Clone)]
pub struct StepResult {
    /// Zero-based index of this step within the run.
    pub step_index: u32,
    /// Content returned by the model for this step.
    pub response_content: Vec<Content>,
    /// The `Content::ToolCall` parts found in `response_content`.
    pub tool_calls: Vec<ToolCallPart>,
    /// Results of executing `tool_calls`. Empty when the tools were not run,
    /// i.e. when there were no tool calls or the finish reason was not
    /// `FinishReason::ToolCalls`.
    pub tool_results: Vec<ToolResultPart>,
    /// Finish reason reported by the provider for this step.
    pub finish_reason: FinishReason,
    /// Usage reported by the provider for this step alone.
    pub usage: Usage,
}
248
249pub fn generate_text() -> GenerateTextBuilder<(), ()> {
269 GenerateTextBuilder::new()
270}
271
272fn extract_tool_calls(content: &[Content]) -> Vec<ToolCallPart> {
274 content
275 .iter()
276 .filter_map(|c| match c {
277 Content::ToolCall(tc) => Some(tc.clone()),
278 _ => None,
279 })
280 .collect()
281}
282
283fn extract_text(content: &[Content]) -> String {
284 content
285 .iter()
286 .filter_map(|c| match c {
287 Content::Text(t) => Some(t.text.clone()),
288 _ => None,
289 })
290 .collect::<Vec<_>>()
291 .join("")
292}