rig/
streaming.rs

1//! This module provides functionality for working with streaming completion models.
2//! It provides traits and types for generating streaming completion requests and
3//! handling streaming completion responses.
4//!
5//! The main traits defined in this module are:
6//! - [StreamingPrompt]: Defines a high-level streaming LLM one-shot prompt interface
7//! - [StreamingChat]: Defines a high-level streaming LLM chat interface with history
8//! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface
9//! - [StreamingCompletionModel]: Defines a streaming completion model interface
10//!
11
12use crate::agent::Agent;
13use crate::completion::{
14    CompletionError, CompletionModel, CompletionRequest, CompletionRequestBuilder, Message,
15};
16use futures::{Stream, StreamExt};
17use std::boxed::Box;
18use std::fmt::{Display, Formatter};
19use std::future::Future;
20use std::pin::Pin;
21
22/// Enum representing a streaming chunk from the model
23#[derive(Debug)]
24pub enum StreamingChoice {
25    /// A text chunk from a message response
26    Message(String),
27
28    /// A tool call response chunk
29    ToolCall(String, String, serde_json::Value),
30}
31
32impl Display for StreamingChoice {
33    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34        match self {
35            StreamingChoice::Message(text) => write!(f, "{}", text),
36            StreamingChoice::ToolCall(name, id, params) => {
37                write!(f, "Tool call: {} {} {:?}", name, id, params)
38            }
39        }
40    }
41}
42
43#[cfg(not(target_arch = "wasm32"))]
44pub type StreamingResult =
45    Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>> + Send>>;
46
47#[cfg(target_arch = "wasm32")]
48pub type StreamingResult = Pin<Box<dyn Stream<Item = Result<StreamingChoice, CompletionError>>>>;
49
50/// Trait for high-level streaming prompt interface
51pub trait StreamingPrompt: Send + Sync {
52    /// Stream a simple prompt to the model
53    fn stream_prompt(
54        &self,
55        prompt: &str,
56    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
57}
58
59/// Trait for high-level streaming chat interface
60pub trait StreamingChat: Send + Sync {
61    /// Stream a chat with history to the model
62    fn stream_chat(
63        &self,
64        prompt: &str,
65        chat_history: Vec<Message>,
66    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
67}
68
69/// Trait for low-level streaming completion interface
70pub trait StreamingCompletion<M: StreamingCompletionModel> {
71    /// Generate a streaming completion from a request
72    fn stream_completion(
73        &self,
74        prompt: &str,
75        chat_history: Vec<Message>,
76    ) -> impl Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>;
77}
78
79/// Trait defining a streaming completion model
80pub trait StreamingCompletionModel: CompletionModel {
81    /// Stream a completion response for the given request
82    fn stream(
83        &self,
84        request: CompletionRequest,
85    ) -> impl Future<Output = Result<StreamingResult, CompletionError>>;
86}
87
88/// helper function to stream a completion request to stdout
89pub async fn stream_to_stdout<M: StreamingCompletionModel>(
90    agent: Agent<M>,
91    stream: &mut StreamingResult,
92) -> Result<(), std::io::Error> {
93    print!("Response: ");
94    while let Some(chunk) = stream.next().await {
95        match chunk {
96            Ok(StreamingChoice::Message(text)) => {
97                print!("{}", text);
98                std::io::Write::flush(&mut std::io::stdout())?;
99            }
100            Ok(StreamingChoice::ToolCall(name, _, params)) => {
101                let res = agent
102                    .tools
103                    .call(&name, params.to_string())
104                    .await
105                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
106                println!("\nResult: {}", res);
107            }
108            Err(e) => {
109                eprintln!("Error: {}", e);
110                break;
111            }
112        }
113    }
114    println!(); // New line after streaming completes
115
116    Ok(())
117}