// ai_sdk_core/text/stream.rs

1use super::GenerationConfig;
2use crate::impl_builder_core;
3use crate::text::accumulate_usage;
4use crate::tool::ToolExecutor;
5use crate::Result;
6use crate::{error::GenerateError, text::convert_content};
7use ai_sdk_provider::language_model::{
8    Content, FinishReason, LanguageModel, Message, StreamPart, TextPart, Tool as ProviderTool,
9    ToolCallPart, ToolChoice, ToolResultPart, Usage,
10};
11use async_stream::stream;
12use futures::future::BoxFuture;
13use std::future::IntoFuture;
14use std::pin::Pin;
15use std::sync::Arc;
16use tokio_stream::{Stream, StreamExt};
17
18// -----------------------------------------------------------------------------
19// Builder Definition
20// -----------------------------------------------------------------------------
21
/// Builder for streaming text generation.
///
/// This builder allows configuring the model, prompt, tools, and other settings
/// before starting the stream.
///
/// The type parameters track builder state: `M` is the model slot and `P` the
/// prompt slot. Both start as `()` (see `stream_text()`), and `execute()` is
/// only implemented once they are `Arc<dyn LanguageModel>` / `Vec<Message>`.
pub struct StreamTextBuilder<M, P> {
    // Language model to stream from; filled in by the macro-generated setter.
    model: M,
    // Prompt/messages to send; filled in by the macro-generated setter.
    prompt: P,
    // Shared generation settings (temperature, max_tokens, tools, retry policy, ...).
    config: GenerationConfig,
}

// Use the macro to implement new(), setters, and transitions
// (presumably the type-state transitions from () to the concrete slots —
// see the macro definition for the exact generated API).
impl_builder_core!(StreamTextBuilder);
34
35// -----------------------------------------------------------------------------
36// Execution Logic
37// -----------------------------------------------------------------------------
38
39impl StreamTextBuilder<Arc<dyn LanguageModel>, Vec<Message>> {
40    /// Execute the streaming text generation.
41    ///
42    /// This returns a `StreamTextResult` which contains the stream of events.
43    ///
44    /// # Returns
45    /// A `Result` containing `StreamTextResult` on success, or a `GenerateError` on failure.
46    pub async fn execute(self) -> Result<StreamTextResult> {
47        let model = self.model;
48        let messages = self.prompt;
49        let config = self.config;
50
51        // Streaming defaults to more steps if not set, or use config
52        let max_steps = if config.max_steps == 1 {
53            5
54        } else {
55            config.max_steps
56        };
57
58        let tool_executor = ToolExecutor::new(config.tools);
59
60        let stream_impl = create_multi_step_stream(
61            model,
62            messages,
63            tool_executor,
64            max_steps,
65            config.temperature,
66            config.max_tokens,
67            config.retry_policy,
68        );
69
70        Ok(StreamTextResult {
71            stream: Box::pin(stream_impl),
72        })
73    }
74}
75
76// -----------------------------------------------------------------------------
77// IntoFuture
78// -----------------------------------------------------------------------------
79
80impl IntoFuture for StreamTextBuilder<Arc<dyn LanguageModel>, Vec<Message>> {
81    type Output = Result<StreamTextResult>;
82    type IntoFuture = BoxFuture<'static, Self::Output>;
83
84    fn into_future(self) -> Self::IntoFuture {
85        Box::pin(self.execute())
86    }
87}
88
89// -----------------------------------------------------------------------------
90// Helper Structs
91// -----------------------------------------------------------------------------
92
/// Result of a streaming text generation.
///
/// This struct wraps the underlying stream and provides methods to access it.
pub struct StreamTextResult {
    // Boxed, pinned stream of `TextStreamPart` events (or errors), produced
    // by `create_multi_step_stream`.
    stream: Pin<Box<dyn Stream<Item = Result<TextStreamPart>> + Send>>,
}
99
100impl StreamTextResult {
101    /// Returns a mutable reference to the stream.
102    pub fn stream_mut(&mut self) -> Pin<&mut (dyn Stream<Item = Result<TextStreamPart>> + Send)> {
103        self.stream.as_mut()
104    }
105
106    /// Consumes the result and returns the underlying stream.
107    pub fn into_stream(self) -> Pin<Box<dyn Stream<Item = Result<TextStreamPart>> + Send>> {
108        self.stream
109    }
110}
111
/// A part of the text stream.
///
/// This enum represents the different types of events that can occur during streaming.
/// Consumers typically `match` on these while draining `StreamTextResult::into_stream()`.
#[derive(Debug, Clone)]
pub enum TextStreamPart {
    /// A delta of text generated by the model.
    TextDelta(String),
    /// A tool call generated by the model.
    ToolCall(ToolCallPart),
    /// The result of a tool execution.
    ToolResult(ToolResultPart),
    /// Indicates that a step (LLM call) has finished.
    /// Emitted once per step; more steps may follow when the step finished
    /// with tool calls.
    StepFinish {
        /// The index of the finished step (0-based, bounded by `max_steps`).
        step_index: u32,
        /// The reason why the step finished.
        finish_reason: FinishReason,
    },
    /// Indicates that the entire generation process has finished.
    /// Always the last event of the stream.
    Finish {
        /// The total usage across all steps.
        total_usage: Usage,
    },
}
136
137/// Streams text from a language model.
138///
139/// This function creates a builder that allows you to configure the model, prompt,
140/// tools, and other parameters.
141///
142/// # Example
143///
144/// ```rust,ignore
145/// use ai_sdk_core::stream_text;
146/// use ai_sdk_openai::openai;
147/// use tokio_stream::StreamExt;
148///
149/// let mut result = stream_text()
150///     .model(openai("gpt-4"))
151///     .prompt("Write a poem")
152///     .execute()
153///     .await?;
154///
155/// let mut stream = result.into_stream();
156/// while let Some(part) = stream.next().await {
157///     match part? {
158///         TextStreamPart::TextDelta(delta) => print!("{}", delta),
159///         _ => {}
160///     }
161/// }
162/// ```
163pub fn stream_text() -> StreamTextBuilder<(), ()> {
164    StreamTextBuilder::new()
165}
166
167// -----------------------------------------------------------------------------
168// Stream Logic
169// -----------------------------------------------------------------------------
170
171fn create_multi_step_stream(
172    model: Arc<dyn LanguageModel>,
173    initial_messages: Vec<Message>,
174    tool_executor: ToolExecutor,
175    max_steps: u32,
176    temperature: Option<f32>,
177    max_tokens: Option<u32>,
178    retry_policy: crate::retry::RetryPolicy,
179) -> impl Stream<Item = Result<TextStreamPart>> {
180    stream! {
181        let mut messages = initial_messages;
182        let mut total_usage = Usage {
183            input_tokens: Some(0), output_tokens: Some(0), total_tokens: Some(0),
184            reasoning_tokens: None, cached_input_tokens: None,
185        };
186
187        for step_index in 0..max_steps {
188            let mut builder = model.stream(messages.clone());
189
190            if let Some(temperature) = temperature {
191                builder = builder.temperature(temperature);
192            }
193            if let Some(max_tokens) = max_tokens {
194                builder = builder.max_tokens(max_tokens);
195            }
196
197            if !tool_executor.tools().is_empty() {
198                let tool_defs = tool_executor.tool_definitions();
199                builder = builder
200                    .tools(tool_defs.into_iter().map(ProviderTool::Function).collect())
201                    .tool_choice(ToolChoice::Auto);
202            }
203
204            // Retry the initial connection
205            let stream_response = retry_policy.retry(|| {
206                let builder = builder.clone();
207                async move { builder.await }
208            }).await;
209
210            let stream_response = match stream_response {
211                Ok(r) => r,
212                Err(e) => { yield Err(GenerateError::ProviderError(e).into()); return; }
213            };
214
215            let mut step_stream = stream_response.stream;
216            let mut step_content = Vec::new();
217            let mut text_acc = String::new();
218            let mut tool_calls = Vec::new();
219            let mut finish_reason = None;
220
221            while let Some(part_result) = step_stream.next().await {
222                let part = match part_result {
223                    Ok(p) => p,
224                    Err(e) => { yield Err(GenerateError::StreamError(format!("{:?}", e)).into()); return; }
225                };
226
227                match part {
228                    StreamPart::TextDelta { delta, .. } => {
229                        yield Ok(TextStreamPart::TextDelta(delta.clone()));
230                        text_acc.push_str(&delta);
231                    }
232                    StreamPart::ToolCall(tc) => {
233                        tool_calls.push(tc.clone());
234                        step_content.push(Content::ToolCall(tc.clone()));
235                        yield Ok(TextStreamPart::ToolCall(tc));
236                    }
237                    StreamPart::Finish { finish_reason: fr, usage, .. } => {
238                        finish_reason = Some(fr);
239                        accumulate_usage(&mut total_usage, &usage);
240                    }
241                    _ => {}
242                }
243            }
244
245            if !text_acc.is_empty() {
246                step_content.push(Content::Text(TextPart { text: text_acc, provider_metadata: None }));
247            }
248
249            let fr = finish_reason.unwrap_or(FinishReason::Stop);
250            yield Ok(TextStreamPart::StepFinish { step_index, finish_reason: fr });
251
252            if tool_calls.is_empty() || fr != FinishReason::ToolCalls {
253                break;
254            }
255
256            let tool_results = tool_executor.execute_tools(tool_calls.clone()).await;
257            for result in &tool_results {
258                yield Ok(TextStreamPart::ToolResult(result.clone()));
259            }
260
261            messages.push(Message::Assistant {
262                content: step_content.into_iter().map(convert_content).collect(),
263            });
264            messages.push(Message::Tool { content: tool_results });
265        }
266
267        yield Ok(TextStreamPart::Finish { total_usage });
268    }
269}