1use std::collections::HashMap;
4
5use serde::Serialize;
6use crate::types::chat_api::ChatStreamChunk;
7use crate::types::response_api::{
8 ResponseReasoning, ResponseRequest, ResponseTextConfig, Tool, ToolChoice,
9};
10
11use super::super::util::{extract_queries_from_arguments, map_tool_name_to_output_type};
12
/// Mutable state accumulated while translating a stream of upstream
/// `ChatStreamChunk`s into a Responses-API response.
///
/// `StreamState::build_response_object` turns the accumulated text, reasoning,
/// tool calls, usage and echoed request fields into the final response object.
#[derive(Debug, Clone)]
pub struct StreamState {
    /// Id of the response being produced; also used to derive item ids
    /// (`msg_{response_id}`, `reasoning_{response_id}`).
    pub response_id: String,
    /// Id of the assistant message output item (`msg_{response_id}` when built by `new`).
    pub output_id: String,
    /// NOTE(review): presumably the content-part index inside the message item;
    /// `new` sets it to 0 and it is not read in this file — confirm against the stream handler.
    pub content_index: u32,
    /// Accumulated assistant text; emitted as the message's `OutputText` part.
    pub full_text: String,
    /// Accumulated reasoning text; emitted as the reasoning item's content.
    pub reasoning_text: String,
    /// Starts `true` in `new`; presumably cleared after the first chunk is handled — confirm.
    pub is_first_chunk: bool,
    /// Whether the assistant message item was opened; the message only appears in the
    /// final output when this is set and there is text or refusal content.
    pub is_output_item_added: bool,
    /// NOTE(review): presumably whether the message's content part has been announced — confirm.
    pub is_content_part_added: bool,
    /// Whether a reasoning item was opened; the reasoning item only appears in the
    /// final output when this is set and `reasoning_text` is non-empty.
    pub is_reasoning_added: bool,
    /// NOTE(review): presumably whether a function-call item has been announced — confirm.
    pub is_function_call_item_added: bool,
    /// NOTE(review): presumably set once completion events have been emitted — confirm.
    pub is_completed: bool,
    /// Tool calls that are still in progress (presumably still receiving argument deltas).
    pub current_tool_calls: Vec<ToolCallState>,
    /// Finished tool calls; each one becomes an output item in the final response.
    pub completed_tool_calls: Vec<ToolCallState>,
    /// Model name reported on the response object.
    pub model: String,
    /// Prompt token count copied from upstream `usage.prompt_tokens`.
    pub input_tokens: Option<i64>,
    /// Completion token count copied from upstream `usage.completion_tokens`.
    pub output_tokens: Option<i64>,
    /// Total token count copied from upstream `usage.total_tokens`.
    pub total_tokens: Option<i64>,
    /// Cached prompt tokens (`usage.prompt_tokens_details.cached_tokens`).
    pub cached_tokens: Option<i64>,
    /// Reasoning tokens (`usage.completion_tokens_details.reasoning_tokens`).
    pub reasoning_tokens: Option<i64>,
    /// NOTE(review): presumably buffers partial thinking markup split across chunks — confirm.
    pub thinking_buffer: String,
    /// NOTE(review): presumably true while inside an inline thinking section — confirm.
    pub is_thinking: bool,
    /// NOTE(review): presumably the next free `output_index` for a new output item; starts at 0.
    pub next_output_index: u32,
    /// Output index assigned to the assistant message item, if one was opened.
    pub text_output_index: Option<u32>,
    /// Output index assigned to the reasoning item, if one was opened.
    pub reasoning_output_index: Option<u32>,
    /// Request fields echoed back on the response object; `None` disables echoing.
    pub request_context: Option<ResponseRequestContext>,
    /// Status string reported on the final response; `new` initializes it to `"completed"`.
    pub final_status: String,
    /// When set, reported as `incomplete_details: {"reason": ...}` on the response.
    pub incomplete_reason: Option<String>,
    /// Accumulated refusal text; emitted as a `Refusal` content part on the message.
    pub refusal_text: String,
}
54
/// Progress of a single tool call while it is being streamed.
#[derive(Debug, Clone)]
pub struct ToolCallState {
    /// NOTE(review): presumably the id the upstream chat API gave this call, if any — confirm.
    pub upstream_id: Option<String>,
    /// Id of the output item produced for this call in the final response.
    pub id: String,
    /// Call id reported on the output item (`call_id`).
    pub call_id: String,
    /// NOTE(review): presumably the Responses-API item type string for this call — not read here, confirm.
    pub item_type: String,
    /// Tool name; also used to decide the output item type of the call.
    pub name: String,
    /// Accumulated JSON argument text for the call.
    pub arguments: String,
    /// NOTE(review): presumably the `output_index` assigned to this call's item in the stream — confirm.
    pub output_index: u32,
    /// NOTE(review): presumably the `index` of this call in the upstream chat `tool_calls` deltas — confirm.
    pub chat_api_index: u32,
    /// NOTE(review): presumably how much of `arguments` has already been forwarded as deltas — confirm.
    pub last_args_len: usize,
}
67
/// Snapshot of the originating `ResponseRequest` fields that are echoed back
/// on the final response object.
///
/// Built via `From<&ResponseRequest>`. Field order is part of the serialized
/// JSON shape, because the struct derives `Serialize`.
#[derive(Debug, Clone, Serialize)]
pub struct ResponseRequestContext {
    /// System/developer instructions from the request.
    pub instructions: Option<String>,
    /// `max_output_tokens`, falling back to the legacy `max_tokens` field.
    pub max_output_tokens: Option<u32>,
    /// Whether parallel tool calls were allowed.
    pub parallel_tool_calls: Option<bool>,
    /// Id of the response this request continues, if any.
    pub previous_response_id: Option<String>,
    /// Reasoning configuration from the request.
    pub reasoning: Option<ResponseReasoning>,
    /// Whether the request asked for the response to be stored.
    pub store: Option<bool>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Text output configuration (format).
    pub text: Option<ResponseTextConfig>,
    /// Tool choice policy.
    pub tool_choice: ToolChoice,
    /// Tools declared on the request.
    pub tools: Vec<Tool>,
    /// Nucleus-sampling parameter.
    pub top_p: Option<f32>,
    /// Truncation strategy string.
    pub truncation: Option<String>,
    /// End-user identifier.
    pub user: Option<String>,
    /// Request metadata, plus an `x_proxy_tool_map` entry summarizing named tools
    /// when any exist; `None` when the resulting map is empty.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
85
86impl From<&ResponseRequest> for ResponseRequestContext {
87 fn from(req: &ResponseRequest) -> Self {
88 let mut metadata = req.metadata.clone().unwrap_or_default();
89 let tool_map: serde_json::Map<String, serde_json::Value> = req
90 .tools
91 .iter()
92 .filter_map(|t| {
93 t.name.as_ref().map(|name| {
94 (
95 name.clone(),
96 serde_json::json!({
97 "type": t.tool_type,
98 "strict": t.strict,
99 "extra": t.extra,
100 }),
101 )
102 })
103 })
104 .collect();
105 if !tool_map.is_empty() {
106 metadata.insert(
107 "x_proxy_tool_map".to_string(),
108 serde_json::Value::Object(tool_map),
109 );
110 }
111
112 Self {
113 instructions: req.instructions.clone(),
114 max_output_tokens: req.max_output_tokens.or(req.max_tokens),
115 parallel_tool_calls: req.parallel_tool_calls,
116 previous_response_id: req.previous_response_id.clone(),
117 reasoning: req.reasoning.clone(),
118 store: req.store,
119 temperature: req.temperature,
120 text: req.text.clone(),
121 tool_choice: req.tool_choice.clone(),
122 tools: req.tools.clone(),
123 top_p: req.top_p,
124 truncation: req.truncation.clone(),
125 user: req.user.clone(),
126 metadata: if metadata.is_empty() {
127 None
128 } else {
129 Some(metadata)
130 },
131 }
132 }
133}
134
135impl StreamState {
136 pub fn new(
138 response_id: String,
139 model: String,
140 request_context: Option<ResponseRequestContext>,
141 ) -> Self {
142 Self {
143 response_id: response_id.clone(),
144 output_id: format!("msg_{}", response_id),
145 content_index: 0,
146 full_text: String::new(),
147 reasoning_text: String::new(),
148 is_first_chunk: true,
149 is_output_item_added: false,
150 is_content_part_added: false,
151 is_reasoning_added: false,
152 is_function_call_item_added: false,
153 is_completed: false,
154 current_tool_calls: Vec::new(),
155 completed_tool_calls: Vec::new(),
156 model,
157 input_tokens: None,
158 output_tokens: None,
159 total_tokens: None,
160 cached_tokens: None,
161 reasoning_tokens: None,
162 thinking_buffer: String::new(),
163 is_thinking: false,
164 next_output_index: 0,
165 text_output_index: None,
166 reasoning_output_index: None,
167 request_context,
168 final_status: "completed".to_string(),
169 incomplete_reason: None,
170 refusal_text: String::new(),
171 }
172 }
173
174 pub fn update_usage(&mut self, chunk: &ChatStreamChunk) {
176 if let Some(usage) = &chunk.usage {
177 self.input_tokens = usage.prompt_tokens.map(|v| v as i64);
178 self.output_tokens = usage.completion_tokens.map(|v| v as i64);
179 self.total_tokens = usage.total_tokens.map(|v| v as i64);
180 self.cached_tokens = usage
181 .prompt_tokens_details
182 .as_ref()
183 .and_then(|d| d.cached_tokens)
184 .map(|v| v as i64);
185 self.reasoning_tokens = usage
186 .completion_tokens_details
187 .as_ref()
188 .and_then(|d| d.reasoning_tokens)
189 .map(|v| v as i64);
190 }
191 }
192
193 pub fn build_response_object(&self) -> Box<crate::types::response_api::ResponseObject> {
195 use crate::types::response_api::{
196 InputTokensDetails, OutputItemType, OutputTokensDetails, ResponseContentPart, ResponseObject,
197 ResponseOutputItem, ResponseTextConfig, ResponseTextFormat, Usage,
198 };
199 use chrono::Utc;
200
201 let mut output = Vec::new();
202
203 if self.is_reasoning_added && !self.reasoning_text.is_empty() {
205 output.push(ResponseOutputItem {
206 id: format!("reasoning_{}", self.response_id),
207 item_type: OutputItemType::Reasoning,
208 status: Some("completed".to_string()),
209 content: Some(vec![ResponseContentPart::OutputText {
210 text: self.reasoning_text.clone(),
211 annotations: vec![],
212 }]),
213 role: None,
214 name: None,
215 arguments: None,
216 call_id: None,
217 queries: None,
218 results: None,
219 });
220 }
221
222 if self.is_output_item_added && (!self.full_text.is_empty() || !self.refusal_text.is_empty()) {
224 let mut content_parts = Vec::new();
225 if !self.full_text.is_empty() {
226 content_parts.push(ResponseContentPart::OutputText {
227 text: self.full_text.clone(),
228 annotations: vec![],
229 });
230 }
231 if !self.refusal_text.is_empty() {
232 content_parts.push(ResponseContentPart::Refusal {
233 refusal: self.refusal_text.clone(),
234 });
235 }
236 output.push(ResponseOutputItem {
237 id: self.output_id.clone(),
238 item_type: OutputItemType::Message,
239 status: Some("completed".to_string()),
240 content: Some(content_parts),
241 role: Some("assistant".to_string()),
242 name: None,
243 arguments: None,
244 call_id: None,
245 queries: None,
246 results: None,
247 });
248 }
249
250 for tc in &self.completed_tool_calls {
252 let item_type = map_tool_name_to_output_type(&tc.name, self.request_context.as_ref().map(|ctx| &ctx.tools));
253 let (queries, results) = if item_type != OutputItemType::FunctionCall {
254 (extract_queries_from_arguments(&tc.arguments), Some(serde_json::Value::Null))
255 } else {
256 (None, None)
257 };
258 output.push(ResponseOutputItem {
259 id: tc.id.clone(),
260 item_type,
261 status: Some("completed".to_string()),
262 content: None,
263 role: None,
264 name: Some(tc.name.clone()),
265 arguments: Some(tc.arguments.clone()),
266 call_id: Some(tc.call_id.clone()),
267 queries,
268 results,
269 });
270 }
271
272 let usage = if self.input_tokens.is_some() || self.output_tokens.is_some() || self.total_tokens.is_some() {
273 Some(Usage {
274 input_tokens: self.input_tokens,
275 input_tokens_details: Some(InputTokensDetails {
276 cached_tokens: self.cached_tokens.unwrap_or(0),
277 }),
278 output_tokens: self.output_tokens,
279 output_tokens_details: Some(OutputTokensDetails {
280 reasoning_tokens: self.reasoning_tokens.unwrap_or(0),
281 }),
282 total_tokens: self.total_tokens,
283 })
284 } else {
285 None
286 };
287
288 Box::new(ResponseObject {
289 id: self.response_id.clone(),
290 object: "response".to_string(),
291 status: self.final_status.clone(),
292 model: self.model.clone(),
293 created_at: Utc::now().timestamp(),
294 completed_at: Some(Utc::now().timestamp()),
295 error: None,
296 incomplete_details: self
297 .incomplete_reason
298 .as_ref()
299 .map(|reason| serde_json::json!({ "reason": reason })),
300 instructions: self
301 .request_context
302 .as_ref()
303 .and_then(|ctx| ctx.instructions.clone()),
304 max_output_tokens: self
305 .request_context
306 .as_ref()
307 .and_then(|ctx| ctx.max_output_tokens),
308 max_tool_calls: None,
309 input: None,
310 output,
311 parallel_tool_calls: self
312 .request_context
313 .as_ref()
314 .and_then(|ctx| ctx.parallel_tool_calls),
315 previous_response_id: self
316 .request_context
317 .as_ref()
318 .and_then(|ctx| ctx.previous_response_id.clone()),
319 reasoning: self
320 .request_context
321 .as_ref()
322 .and_then(|ctx| ctx.reasoning.clone())
323 .or({
324 Some(crate::types::response_api::ResponseReasoning {
325 effort: None,
326 summary: None,
327 })
328 }),
329 store: self.request_context.as_ref().and_then(|ctx| ctx.store),
330 temperature: self
331 .request_context
332 .as_ref()
333 .and_then(|ctx| ctx.temperature),
334 text: self
335 .request_context
336 .as_ref()
337 .and_then(|ctx| ctx.text.clone())
338 .or_else(|| {
339 Some(ResponseTextConfig {
340 format: Some(ResponseTextFormat {
341 format_type: "text".to_string(),
342 name: None,
343 schema: None,
344 strict: None,
345 }),
346 })
347 }),
348 tool_choice: self
349 .request_context
350 .as_ref()
351 .map(|ctx| ctx.tool_choice.clone()),
352 tools: self
353 .request_context
354 .as_ref()
355 .map(|ctx| ctx.tools.clone()),
356 top_p: self.request_context.as_ref().and_then(|ctx| ctx.top_p),
357 truncation: self
358 .request_context
359 .as_ref()
360 .and_then(|ctx| ctx.truncation.clone()),
361 user: self
362 .request_context
363 .as_ref()
364 .and_then(|ctx| ctx.user.clone()),
365 metadata: self
366 .request_context
367 .as_ref()
368 .and_then(|ctx| ctx.metadata.clone()),
369 usage,
370 })
371 }
372}