ai_sdk_core/text/
stream.rs1use super::GenerationConfig;
2use crate::impl_builder_core;
3use crate::text::accumulate_usage;
4use crate::tool::ToolExecutor;
5use crate::Result;
6use crate::{error::GenerateError, text::convert_content};
7use ai_sdk_provider::language_model::{
8 Content, FinishReason, LanguageModel, Message, StreamPart, TextPart, Tool as ProviderTool,
9 ToolCallPart, ToolChoice, ToolResultPart, Usage,
10};
11use async_stream::stream;
12use futures::future::BoxFuture;
13use std::future::IntoFuture;
14use std::pin::Pin;
15use std::sync::Arc;
16use tokio_stream::{Stream, StreamExt};
17
18pub struct StreamTextBuilder<M, P> {
27 model: M,
28 prompt: P,
29 config: GenerationConfig,
30}
31
32impl_builder_core!(StreamTextBuilder);
34
35impl StreamTextBuilder<Arc<dyn LanguageModel>, Vec<Message>> {
40 pub async fn execute(self) -> Result<StreamTextResult> {
47 let model = self.model;
48 let messages = self.prompt;
49 let config = self.config;
50
51 let max_steps = if config.max_steps == 1 {
53 5
54 } else {
55 config.max_steps
56 };
57
58 let tool_executor = ToolExecutor::new(config.tools);
59
60 let stream_impl = create_multi_step_stream(
61 model,
62 messages,
63 tool_executor,
64 max_steps,
65 config.temperature,
66 config.max_tokens,
67 config.retry_policy,
68 );
69
70 Ok(StreamTextResult {
71 stream: Box::pin(stream_impl),
72 })
73 }
74}
75
76impl IntoFuture for StreamTextBuilder<Arc<dyn LanguageModel>, Vec<Message>> {
81 type Output = Result<StreamTextResult>;
82 type IntoFuture = BoxFuture<'static, Self::Output>;
83
84 fn into_future(self) -> Self::IntoFuture {
85 Box::pin(self.execute())
86 }
87}
88
89pub struct StreamTextResult {
97 stream: Pin<Box<dyn Stream<Item = Result<TextStreamPart>> + Send>>,
98}
99
100impl StreamTextResult {
101 pub fn stream_mut(&mut self) -> Pin<&mut (dyn Stream<Item = Result<TextStreamPart>> + Send)> {
103 self.stream.as_mut()
104 }
105
106 pub fn into_stream(self) -> Pin<Box<dyn Stream<Item = Result<TextStreamPart>> + Send>> {
108 self.stream
109 }
110}
111
112#[derive(Debug, Clone)]
116pub enum TextStreamPart {
117 TextDelta(String),
119 ToolCall(ToolCallPart),
121 ToolResult(ToolResultPart),
123 StepFinish {
125 step_index: u32,
127 finish_reason: FinishReason,
129 },
130 Finish {
132 total_usage: Usage,
134 },
135}
136
137pub fn stream_text() -> StreamTextBuilder<(), ()> {
164 StreamTextBuilder::new()
165}
166
167fn create_multi_step_stream(
172 model: Arc<dyn LanguageModel>,
173 initial_messages: Vec<Message>,
174 tool_executor: ToolExecutor,
175 max_steps: u32,
176 temperature: Option<f32>,
177 max_tokens: Option<u32>,
178 retry_policy: crate::retry::RetryPolicy,
179) -> impl Stream<Item = Result<TextStreamPart>> {
180 stream! {
181 let mut messages = initial_messages;
182 let mut total_usage = Usage {
183 input_tokens: Some(0), output_tokens: Some(0), total_tokens: Some(0),
184 reasoning_tokens: None, cached_input_tokens: None,
185 };
186
187 for step_index in 0..max_steps {
188 let mut builder = model.stream(messages.clone());
189
190 if let Some(temperature) = temperature {
191 builder = builder.temperature(temperature);
192 }
193 if let Some(max_tokens) = max_tokens {
194 builder = builder.max_tokens(max_tokens);
195 }
196
197 if !tool_executor.tools().is_empty() {
198 let tool_defs = tool_executor.tool_definitions();
199 builder = builder
200 .tools(tool_defs.into_iter().map(ProviderTool::Function).collect())
201 .tool_choice(ToolChoice::Auto);
202 }
203
204 let stream_response = retry_policy.retry(|| {
206 let builder = builder.clone();
207 async move { builder.await }
208 }).await;
209
210 let stream_response = match stream_response {
211 Ok(r) => r,
212 Err(e) => { yield Err(GenerateError::ProviderError(e).into()); return; }
213 };
214
215 let mut step_stream = stream_response.stream;
216 let mut step_content = Vec::new();
217 let mut text_acc = String::new();
218 let mut tool_calls = Vec::new();
219 let mut finish_reason = None;
220
221 while let Some(part_result) = step_stream.next().await {
222 let part = match part_result {
223 Ok(p) => p,
224 Err(e) => { yield Err(GenerateError::StreamError(format!("{:?}", e)).into()); return; }
225 };
226
227 match part {
228 StreamPart::TextDelta { delta, .. } => {
229 yield Ok(TextStreamPart::TextDelta(delta.clone()));
230 text_acc.push_str(&delta);
231 }
232 StreamPart::ToolCall(tc) => {
233 tool_calls.push(tc.clone());
234 step_content.push(Content::ToolCall(tc.clone()));
235 yield Ok(TextStreamPart::ToolCall(tc));
236 }
237 StreamPart::Finish { finish_reason: fr, usage, .. } => {
238 finish_reason = Some(fr);
239 accumulate_usage(&mut total_usage, &usage);
240 }
241 _ => {}
242 }
243 }
244
245 if !text_acc.is_empty() {
246 step_content.push(Content::Text(TextPart { text: text_acc, provider_metadata: None }));
247 }
248
249 let fr = finish_reason.unwrap_or(FinishReason::Stop);
250 yield Ok(TextStreamPart::StepFinish { step_index, finish_reason: fr });
251
252 if tool_calls.is_empty() || fr != FinishReason::ToolCalls {
253 break;
254 }
255
256 let tool_results = tool_executor.execute_tools(tool_calls.clone()).await;
257 for result in &tool_results {
258 yield Ok(TextStreamPart::ToolResult(result.clone()));
259 }
260
261 messages.push(Message::Assistant {
262 content: step_content.into_iter().map(convert_content).collect(),
263 });
264 messages.push(Message::Tool { content: tool_results });
265 }
266
267 yield Ok(TextStreamPart::Finish { total_usage });
268 }
269}