//! Source: `codex_convert_proxy/convert/streaming/state.rs`
//!
//! Streaming state types: `StreamState`, `ToolCallState`, and `ResponseRequestContext`.
2
3use std::collections::HashMap;
4
5use serde::Serialize;
6use crate::types::chat_api::ChatStreamChunk;
7use crate::types::response_api::{
8    ResponseReasoning, ResponseRequest, ResponseTextConfig, Tool, ToolChoice,
9};
10
11use super::super::util::{extract_queries_from_arguments, map_tool_name_to_output_type};
12
13/// Streaming converter state for tracking incremental changes.
14#[derive(Debug, Clone)]
15pub struct StreamState {
16    pub response_id: String,
17    pub output_id: String,
18    pub content_index: u32,
19    pub full_text: String,
20    pub reasoning_text: String,
21    pub is_first_chunk: bool,
22    pub is_output_item_added: bool,
23    pub is_content_part_added: bool,
24    pub is_reasoning_added: bool,
25    pub is_function_call_item_added: bool,
26    pub is_completed: bool,
27    pub current_tool_calls: Vec<ToolCallState>,
28    pub completed_tool_calls: Vec<ToolCallState>,
29    pub model: String,
30    pub input_tokens: Option<i64>,
31    pub output_tokens: Option<i64>,
32    pub total_tokens: Option<i64>,
33    pub cached_tokens: Option<i64>,
34    pub reasoning_tokens: Option<i64>,
35    /// Buffer for incomplete think/thought tags during streaming
36    pub thinking_buffer: String,
37    /// Whether we're currently inside a thinking tag
38    pub is_thinking: bool,
39    /// Next available output_index for sequential assignment
40    pub next_output_index: u32,
41    /// Stored output_index for text message items
42    pub text_output_index: Option<u32>,
43    /// Stored output_index for reasoning items
44    pub reasoning_output_index: Option<u32>,
45    /// Original Responses request fields for protocol-consistent events.
46    pub request_context: Option<ResponseRequestContext>,
47    /// Final response status derived from finish_reason.
48    pub final_status: String,
49    /// Optional incomplete reason when final_status is incomplete.
50    pub incomplete_reason: Option<String>,
51    /// Refusal text accumulated from streaming deltas.
52    pub refusal_text: String,
53}
54
/// State tracked for a single tool call while it is being streamed.
///
/// `StreamState::build_response_object` reads `id`, `call_id`, `name` and
/// `arguments`; the remaining fields are bookkeeping for code outside this
/// file section.
#[derive(Debug, Clone)]
pub struct ToolCallState {
    /// Presumably the tool-call id supplied by the upstream chat API, if any.
    pub upstream_id: Option<String>,
    /// Item id used for this call's entry in the final output.
    pub id: String,
    /// Call id reported on this call's output item.
    pub call_id: String,
    /// Item type as a string. The final object does not use it: it re-derives
    /// the type from `name` and the request's tools.
    pub item_type: String,
    /// Tool name as requested by the model.
    pub name: String,
    /// Argument text accumulated so far (expected to be JSON once complete).
    pub arguments: String,
    /// Output index assigned to this call's item.
    pub output_index: u32,
    /// Presumably the `index` of this call in the upstream chat stream.
    pub chat_api_index: u32,
    /// Presumably the length of `arguments` already forwarded downstream.
    pub last_args_len: usize,
}
67
68#[derive(Debug, Clone, Serialize)]
69pub struct ResponseRequestContext {
70    pub instructions: Option<String>,
71    pub max_output_tokens: Option<u32>,
72    pub parallel_tool_calls: Option<bool>,
73    pub previous_response_id: Option<String>,
74    pub reasoning: Option<ResponseReasoning>,
75    pub store: Option<bool>,
76    pub temperature: Option<f32>,
77    pub text: Option<ResponseTextConfig>,
78    pub tool_choice: ToolChoice,
79    pub tools: Vec<Tool>,
80    pub top_p: Option<f32>,
81    pub truncation: Option<String>,
82    pub user: Option<String>,
83    pub metadata: Option<HashMap<String, serde_json::Value>>,
84}
85
86impl From<&ResponseRequest> for ResponseRequestContext {
87    fn from(req: &ResponseRequest) -> Self {
88        let mut metadata = req.metadata.clone().unwrap_or_default();
89        let tool_map: serde_json::Map<String, serde_json::Value> = req
90            .tools
91            .iter()
92            .filter_map(|t| {
93                t.name.as_ref().map(|name| {
94                    (
95                        name.clone(),
96                        serde_json::json!({
97                            "type": t.tool_type,
98                            "strict": t.strict,
99                            "extra": t.extra,
100                        }),
101                    )
102                })
103            })
104            .collect();
105        if !tool_map.is_empty() {
106            metadata.insert(
107                "x_proxy_tool_map".to_string(),
108                serde_json::Value::Object(tool_map),
109            );
110        }
111
112        Self {
113            instructions: req.instructions.clone(),
114            max_output_tokens: req.max_output_tokens.or(req.max_tokens),
115            parallel_tool_calls: req.parallel_tool_calls,
116            previous_response_id: req.previous_response_id.clone(),
117            reasoning: req.reasoning.clone(),
118            store: req.store,
119            temperature: req.temperature,
120            text: req.text.clone(),
121            tool_choice: req.tool_choice.clone(),
122            tools: req.tools.clone(),
123            top_p: req.top_p,
124            truncation: req.truncation.clone(),
125            user: req.user.clone(),
126            metadata: if metadata.is_empty() {
127                None
128            } else {
129                Some(metadata)
130            },
131        }
132    }
133}
134
135impl StreamState {
136    /// Create a new stream state.
137    pub fn new(
138        response_id: String,
139        model: String,
140        request_context: Option<ResponseRequestContext>,
141    ) -> Self {
142        Self {
143            response_id: response_id.clone(),
144            output_id: format!("msg_{}", response_id),
145            content_index: 0,
146            full_text: String::new(),
147            reasoning_text: String::new(),
148            is_first_chunk: true,
149            is_output_item_added: false,
150            is_content_part_added: false,
151            is_reasoning_added: false,
152            is_function_call_item_added: false,
153            is_completed: false,
154            current_tool_calls: Vec::new(),
155            completed_tool_calls: Vec::new(),
156            model,
157            input_tokens: None,
158            output_tokens: None,
159            total_tokens: None,
160            cached_tokens: None,
161            reasoning_tokens: None,
162            thinking_buffer: String::new(),
163            is_thinking: false,
164            next_output_index: 0,
165            text_output_index: None,
166            reasoning_output_index: None,
167            request_context,
168            final_status: "completed".to_string(),
169            incomplete_reason: None,
170            refusal_text: String::new(),
171        }
172    }
173
174    /// Update usage from a ChatStreamChunk.
175    pub fn update_usage(&mut self, chunk: &ChatStreamChunk) {
176        if let Some(usage) = &chunk.usage {
177            self.input_tokens = usage.prompt_tokens.map(|v| v as i64);
178            self.output_tokens = usage.completion_tokens.map(|v| v as i64);
179            self.total_tokens = usage.total_tokens.map(|v| v as i64);
180            self.cached_tokens = usage
181                .prompt_tokens_details
182                .as_ref()
183                .and_then(|d| d.cached_tokens)
184                .map(|v| v as i64);
185            self.reasoning_tokens = usage
186                .completion_tokens_details
187                .as_ref()
188                .and_then(|d| d.reasoning_tokens)
189                .map(|v| v as i64);
190        }
191    }
192
193    /// Build the final ResponseObject with all accumulated outputs.
194    pub fn build_response_object(&self) -> Box<crate::types::response_api::ResponseObject> {
195        use crate::types::response_api::{
196            InputTokensDetails, OutputItemType, OutputTokensDetails, ResponseContentPart, ResponseObject,
197            ResponseOutputItem, ResponseTextConfig, ResponseTextFormat, Usage,
198        };
199        use chrono::Utc;
200
201        let mut output = Vec::new();
202
203        // Add reasoning output if present
204        if self.is_reasoning_added && !self.reasoning_text.is_empty() {
205            output.push(ResponseOutputItem {
206                id: format!("reasoning_{}", self.response_id),
207                item_type: OutputItemType::Reasoning,
208                status: None,
209                content: Some(vec![]),
210                summary: Some(vec![crate::types::response_api::ReasoningSummaryPart::SummaryText {
211                    text: self.reasoning_text.clone(),
212                }]),
213                role: None,
214                name: None,
215                arguments: None,
216                call_id: None,
217                queries: None,
218                results: None,
219                namespace: None,
220            });
221        }
222
223        // Add assistant message output (text and/or refusal)
224        if self.is_output_item_added && (!self.full_text.is_empty() || !self.refusal_text.is_empty()) {
225            let mut content_parts = Vec::new();
226            if !self.full_text.is_empty() {
227                content_parts.push(ResponseContentPart::OutputText {
228                    text: self.full_text.clone(),
229                    annotations: vec![],
230                    logprobs: vec![],
231                });
232            }
233            if !self.refusal_text.is_empty() {
234                content_parts.push(ResponseContentPart::Refusal {
235                    refusal: self.refusal_text.clone(),
236                });
237            }
238            output.push(ResponseOutputItem {
239                id: self.output_id.clone(),
240                item_type: OutputItemType::Message,
241                status: Some("completed".to_string()),
242                content: Some(content_parts),
243                role: Some("assistant".to_string()),
244                name: None,
245                arguments: None,
246                call_id: None,
247                queries: None,
248                results: None,
249                summary: None,
250                namespace: None,
251            });
252        }
253
254        // Add function call outputs
255        for tc in &self.completed_tool_calls {
256            let item_type = map_tool_name_to_output_type(&tc.name, self.request_context.as_ref().map(|ctx| &ctx.tools));
257            let (queries, results) = if item_type != OutputItemType::FunctionCall {
258                (extract_queries_from_arguments(&tc.arguments), Some(serde_json::Value::Null))
259            } else {
260                (None, None)
261            };
262            output.push(ResponseOutputItem {
263                id: tc.id.clone(),
264                item_type,
265                status: Some("completed".to_string()),
266                content: None,
267                role: None,
268                name: Some(tc.name.clone()),
269                arguments: Some(tc.arguments.clone()),
270                call_id: Some(tc.call_id.clone()),
271                queries,
272                results,
273                summary: None,
274                namespace: None,
275            });
276        }
277
278        let usage = if self.input_tokens.is_some() || self.output_tokens.is_some() || self.total_tokens.is_some() {
279            Some(Usage {
280                input_tokens: self.input_tokens,
281                input_tokens_details: Some(InputTokensDetails {
282                    cached_tokens: self.cached_tokens.unwrap_or(0),
283                }),
284                output_tokens: self.output_tokens,
285                output_tokens_details: Some(OutputTokensDetails {
286                    reasoning_tokens: self.reasoning_tokens.unwrap_or(0),
287                }),
288                total_tokens: self.total_tokens,
289            })
290        } else {
291            None
292        };
293
294        Box::new(ResponseObject {
295            id: self.response_id.clone(),
296            object: "response".to_string(),
297            status: self.final_status.clone(),
298            model: self.model.clone(),
299            created_at: Utc::now().timestamp(),
300            completed_at: Some(Utc::now().timestamp()),
301            error: None,
302            incomplete_details: self
303                .incomplete_reason
304                .as_ref()
305                .map(|reason| serde_json::json!({ "reason": reason })),
306            background: None,
307            instructions: self
308                .request_context
309                .as_ref()
310                .and_then(|ctx| ctx.instructions.clone()),
311            max_output_tokens: self
312                .request_context
313                .as_ref()
314                .and_then(|ctx| ctx.max_output_tokens),
315            max_tool_calls: None,
316            input: None,
317            output,
318            parallel_tool_calls: self
319                .request_context
320                .as_ref()
321                .and_then(|ctx| ctx.parallel_tool_calls),
322            previous_response_id: self
323                .request_context
324                .as_ref()
325                .and_then(|ctx| ctx.previous_response_id.clone()),
326            reasoning: self
327                .request_context
328                .as_ref()
329                .and_then(|ctx| ctx.reasoning.clone()),
330            store: self.request_context.as_ref().and_then(|ctx| ctx.store),
331            temperature: self
332                .request_context
333                .as_ref()
334                .and_then(|ctx| ctx.temperature),
335            text: self
336                .request_context
337                .as_ref()
338                .and_then(|ctx| ctx.text.clone())
339                .or_else(|| {
340                    Some(ResponseTextConfig {
341                        format: Some(ResponseTextFormat {
342                            format_type: "text".to_string(),
343                            name: None,
344                            schema: None,
345                            strict: None,
346                        }),
347                    })
348                }),
349            tool_choice: self
350                .request_context
351                .as_ref()
352                .map(|ctx| ctx.tool_choice.clone()),
353            tools: self
354                .request_context
355                .as_ref()
356                .map(|ctx| ctx.tools.clone()),
357            top_p: self.request_context.as_ref().and_then(|ctx| ctx.top_p),
358            truncation: self
359                .request_context
360                .as_ref()
361                .and_then(|ctx| ctx.truncation.clone()),
362            user: self
363                .request_context
364                .as_ref()
365                .and_then(|ctx| ctx.user.clone()),
366            metadata: self
367                .request_context
368                .as_ref()
369                .and_then(|ctx| ctx.metadata.clone()),
370            service_tier: None,
371            top_logprobs: None,
372            usage,
373        })
374    }
375}