openai_agents_rust/
realtime.rs1use async_trait::async_trait;
2use bytes::Bytes;
3use futures_util::StreamExt;
4use reqwest::Client;
5use serde_json::Value;
6use std::pin::Pin;
7
8use crate::config::Config;
9use crate::error::AgentError;
10
11#[async_trait]
13pub trait Realtime: Send + Sync {
14 async fn start_stream(&self) -> Result<Box<dyn StreamItem>, AgentError>;
16}
17
18#[async_trait]
20pub trait StreamItem: Send + Sync {
21 async fn next(&mut self) -> Result<Option<String>, AgentError>;
23}
24
25pub struct OpenAiChatRealtime {
28 client: Client,
29 base_url: String,
30 auth_token: Option<String>,
31 model: String,
32 messages: Vec<Value>,
33 max_tokens: Option<i32>,
35 temperature: Option<f32>,
36}
37
38impl OpenAiChatRealtime {
39 pub fn new_with_messages(config: Config, messages: Vec<Value>) -> Self {
40 let client = Client::builder()
41 .user_agent("openai-agents-rust")
42 .build()
43 .expect("Failed to build reqwest client");
44 let auth_token = if config.api_key.is_empty() {
45 None
46 } else {
47 Some(config.api_key.clone())
48 };
49 Self {
50 client,
51 base_url: config.base_url.clone(),
52 auth_token,
53 model: config.model.clone(),
54 messages,
55 max_tokens: Some(512),
56 temperature: Some(0.2),
57 }
58 }
59
60 pub fn new_simple(config: Config, prompt: &str) -> Self {
61 let messages = vec![serde_json::json!({"role":"user","content":prompt})];
62 Self::new_with_messages(config, messages)
63 }
64
65 fn url(&self) -> String {
66 format!("{}/chat/completions", self.base_url.trim_end_matches('/'))
67 }
68}
69
70#[async_trait]
71impl Realtime for OpenAiChatRealtime {
72 async fn start_stream(&self) -> Result<Box<dyn StreamItem>, AgentError> {
73 let mut body = serde_json::json!({
74 "model": self.model,
75 "messages": self.messages,
76 "stream": true,
77 });
78 if let Some(mt) = self.max_tokens {
79 body["max_tokens"] = serde_json::json!(mt);
80 }
81 if let Some(t) = self.temperature {
82 body["temperature"] = serde_json::json!(t);
83 }
84
85 let mut req = self.client.post(self.url());
86 if let Some(token) = &self.auth_token {
87 req = req.bearer_auth(token);
88 }
89 let resp = req.json(&body).send().await.map_err(AgentError::from)?;
90 let status = resp.status();
91 if !status.is_success() {
92 let text = resp.text().await.unwrap_or_default();
93 return Err(AgentError::Other(format!(
94 "realtime stream failed: HTTP {} — {}",
95 status, text
96 )));
97 }
98
99 let item = SseStreamItem::new(resp);
100 Ok(Box::new(item))
101 }
102}
103
104struct SseStreamItem {
106 stream: tokio::sync::Mutex<
107 Pin<Box<dyn futures_core::Stream<Item = Result<String, AgentError>> + Send>>,
108 >,
109}
110
111impl SseStreamItem {
112 fn new(resp: reqwest::Response) -> Self {
113 let byte_stream = resp.bytes_stream();
114 let s = async_stream::try_stream! {
115 let mut buf: Vec<u8> = Vec::new();
116 futures_util::pin_mut!(byte_stream);
117 while let Some(chunk) = byte_stream.next().await {
118 let chunk: Bytes = chunk.map_err(AgentError::from)?;
119 buf.extend_from_slice(&chunk);
120 loop {
122 if let Some(pos) = buf.iter().position(|b| *b == b'\n') {
123 let line = buf.drain(..=pos).collect::<Vec<u8>>();
124 let line = String::from_utf8_lossy(&line).to_string();
125 let line = line.trim();
126 if line.is_empty() { continue; }
127 if let Some(rest) = line.strip_prefix("data: ") {
128 let data = rest.trim();
129 if data == "[DONE]" { break; }
130 if let Ok(v) = serde_json::from_str::<Value>(data) {
132 let maybe = v
134 .get("choices").and_then(|c| c.as_array()).and_then(|arr| arr.get(0))
135 .and_then(|c0| c0.get("delta").and_then(|d| d.get("content")).and_then(|t| t.as_str()).map(|s| s.to_string())
136 .or_else(|| c0.get("text").and_then(|t| t.as_str()).map(|s| s.to_string())));
137 if let Some(text) = maybe { if !text.is_empty() { yield text; } }
138 }
139 }
140 } else { break; }
141 }
142 }
143 };
144 Self {
145 stream: tokio::sync::Mutex::new(Box::pin(s)),
146 }
147 }
148}
149
150#[async_trait]
151impl StreamItem for SseStreamItem {
152 async fn next(&mut self) -> Result<Option<String>, AgentError> {
153 let mut guard = self.stream.lock().await;
154 match guard.next().await {
155 Some(Ok(s)) => Ok(Some(s)),
156 Some(Err(e)) => Err(e),
157 None => Ok(None),
158 }
159 }
160}