openai_agents_rust/
realtime.rs

1use async_trait::async_trait;
2use bytes::Bytes;
3use futures_util::StreamExt;
4use reqwest::Client;
5use serde_json::Value;
6use std::pin::Pin;
7
8use crate::config::Config;
9use crate::error::AgentError;
10
11/// Trait for real‑time streaming capabilities (e.g., token streams over SSE).
12#[async_trait]
13pub trait Realtime: Send + Sync {
14    /// Start a streaming session and return a handle that yields streamed text deltas.
15    async fn start_stream(&self) -> Result<Box<dyn StreamItem>, AgentError>;
16}
17
18/// Trait representing a single item yielded by a real‑time stream.
19#[async_trait]
20pub trait StreamItem: Send + Sync {
21    /// Retrieve the next chunk of data. Returns `None` when the stream ends.
22    async fn next(&mut self) -> Result<Option<String>, AgentError>;
23}
24
25/// OpenAI‑compatible Chat Completions streaming client (SSE, stream=true).
26/// Construct with the prompt/messages to stream and then call `start_stream`.
27pub struct OpenAiChatRealtime {
28    client: Client,
29    base_url: String,
30    auth_token: Option<String>,
31    model: String,
32    messages: Vec<Value>,
33    // optional parameters
34    max_tokens: Option<i32>,
35    temperature: Option<f32>,
36}
37
38impl OpenAiChatRealtime {
39    pub fn new_with_messages(config: Config, messages: Vec<Value>) -> Self {
40        let client = Client::builder()
41            .user_agent("openai-agents-rust")
42            .build()
43            .expect("Failed to build reqwest client");
44        let auth_token = if config.api_key.is_empty() {
45            None
46        } else {
47            Some(config.api_key.clone())
48        };
49        Self {
50            client,
51            base_url: config.base_url.clone(),
52            auth_token,
53            model: config.model.clone(),
54            messages,
55            max_tokens: Some(512),
56            temperature: Some(0.2),
57        }
58    }
59
60    pub fn new_simple(config: Config, prompt: &str) -> Self {
61        let messages = vec![serde_json::json!({"role":"user","content":prompt})];
62        Self::new_with_messages(config, messages)
63    }
64
65    fn url(&self) -> String {
66        format!("{}/chat/completions", self.base_url.trim_end_matches('/'))
67    }
68}
69
70#[async_trait]
71impl Realtime for OpenAiChatRealtime {
72    async fn start_stream(&self) -> Result<Box<dyn StreamItem>, AgentError> {
73        let mut body = serde_json::json!({
74            "model": self.model,
75            "messages": self.messages,
76            "stream": true,
77        });
78        if let Some(mt) = self.max_tokens {
79            body["max_tokens"] = serde_json::json!(mt);
80        }
81        if let Some(t) = self.temperature {
82            body["temperature"] = serde_json::json!(t);
83        }
84
85        let mut req = self.client.post(self.url());
86        if let Some(token) = &self.auth_token {
87            req = req.bearer_auth(token);
88        }
89        let resp = req.json(&body).send().await.map_err(AgentError::from)?;
90        let status = resp.status();
91        if !status.is_success() {
92            let text = resp.text().await.unwrap_or_default();
93            return Err(AgentError::Other(format!(
94                "realtime stream failed: HTTP {} — {}",
95                status, text
96            )));
97        }
98
99        let item = SseStreamItem::new(resp);
100        Ok(Box::new(item))
101    }
102}
103
104/// StreamItem backed by parsing SSE "data:" lines from an HTTP response.
105struct SseStreamItem {
106    stream: tokio::sync::Mutex<
107        Pin<Box<dyn futures_core::Stream<Item = Result<String, AgentError>> + Send>>,
108    >,
109}
110
111impl SseStreamItem {
112    fn new(resp: reqwest::Response) -> Self {
113        let byte_stream = resp.bytes_stream();
114        let s = async_stream::try_stream! {
115                let mut buf: Vec<u8> = Vec::new();
116                futures_util::pin_mut!(byte_stream);
117                while let Some(chunk) = byte_stream.next().await {
118                    let chunk: Bytes = chunk.map_err(AgentError::from)?;
119                    buf.extend_from_slice(&chunk);
120                    // process complete lines
121                    loop {
122                        if let Some(pos) = buf.iter().position(|b| *b == b'\n') {
123                            let line = buf.drain(..=pos).collect::<Vec<u8>>();
124                            let line = String::from_utf8_lossy(&line).to_string();
125                            let line = line.trim();
126                            if line.is_empty() { continue; }
127                            if let Some(rest) = line.strip_prefix("data: ") {
128                                let data = rest.trim();
129                                if data == "[DONE]" { break; }
130                                // Try parse JSON, extract text deltas
131                                if let Ok(v) = serde_json::from_str::<Value>(data) {
132                                    // OpenAI: choices[0].delta.content or choices[0].text
133                                    let maybe = v
134                                        .get("choices").and_then(|c| c.as_array()).and_then(|arr| arr.get(0))
135                                        .and_then(|c0| c0.get("delta").and_then(|d| d.get("content")).and_then(|t| t.as_str()).map(|s| s.to_string())
136                                            .or_else(|| c0.get("text").and_then(|t| t.as_str()).map(|s| s.to_string())));
137                                    if let Some(text) = maybe { if !text.is_empty() { yield text; } }
138                                }
139                            }
140                        } else { break; }
141                    }
142                }
143        };
144        Self {
145            stream: tokio::sync::Mutex::new(Box::pin(s)),
146        }
147    }
148}
149
150#[async_trait]
151impl StreamItem for SseStreamItem {
152    async fn next(&mut self) -> Result<Option<String>, AgentError> {
153        let mut guard = self.stream.lock().await;
154        match guard.next().await {
155            Some(Ok(s)) => Ok(Some(s)),
156            Some(Err(e)) => Err(e),
157            None => Ok(None),
158        }
159    }
160}