
// sgr_agent/oxide_client.rs

//! OxideClient — LlmClient adapter for the `openai-oxide` crate.
//!
//! Uses the **Responses API** (`POST /responses`) instead of Chat Completions.
//! With the `oxide-ws` feature: a persistent WebSocket connection for roughly 20-25% lower latency.
//! Supports: structured output (json_schema), function calling, multi-turn (previous_response_id).

use crate::client::LlmClient;
use crate::tool::ToolDef;
use crate::types::{LlmConfig, Message, Role, SgrError, ToolCall};
use openai_oxide::OpenAI;
use openai_oxide::config::ClientConfig;
use openai_oxide::types::responses::*;
use serde_json::Value;
/// Emit an OTEL span with usage attributes for Phoenix/OpenInference.
#[cfg(feature = "telemetry")]
fn record_otel_usage(response: &Response, model: &str) {
    use opentelemetry::trace::{Span, Tracer, TracerProvider};

    let provider = opentelemetry::global::tracer_provider();
    let tracer = provider.tracer("sgr-agent");
    let mut otel_span = tracer.start("oxide.responses.api");

    let pt = response
        .usage
        .as_ref()
        .and_then(|u| u.input_tokens)
        .unwrap_or(0);
    let ct = response
        .usage
        .as_ref()
        .and_then(|u| u.output_tokens)
        .unwrap_or(0);

    // OpenInference conventions (Phoenix)
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "openinference.span.kind",
        "LLM",
    ));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "llm.model_name",
        model.to_string(),
    ));
    otel_span.set_attribute(opentelemetry::KeyValue::new("llm.token_count.prompt", pt));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "llm.token_count.completion",
        ct,
    ));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "llm.token_count.total",
        pt + ct,
    ));

    // GenAI conventions (LangSmith)
    otel_span.set_attribute(opentelemetry::KeyValue::new("langsmith.span.kind", "LLM"));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "gen_ai.request.model",
        model.to_string(),
    ));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "gen_ai.response.model",
        response.model.clone(),
    ));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "gen_ai.usage.prompt_tokens",
        pt,
    ));
    otel_span.set_attribute(opentelemetry::KeyValue::new(
        "gen_ai.usage.completion_tokens",
        ct,
    ));

    // Output text
    let output = response.output_text();
    if !output.is_empty() {
        let content = if output.chars().count() > 4000 {
            // Truncate on a char boundary: byte-slicing `&output[..4000]` could
            // panic if index 4000 fell inside a multi-byte UTF-8 character.
            let truncated: String = output.chars().take(4000).collect();
            format!("{truncated}...")
        } else {
            output
        };
        otel_span.set_attribute(opentelemetry::KeyValue::new(
            "gen_ai.completion.0.content",
            content,
        ));
    }

    otel_span.end();
}

#[cfg(not(feature = "telemetry"))]
fn record_otel_usage(_response: &Response, _model: &str) {}

/// LlmClient backed by openai-oxide (Responses API).
///
/// With the `oxide-ws` feature: call `connect_ws()` to upgrade to WebSocket mode.
/// All subsequent calls go over a persistent wss:// connection (roughly 20-25% lower latency).
pub struct OxideClient {
    client: OpenAI,
    pub(crate) model: String,
    pub(crate) temperature: Option<f64>,
    pub(crate) max_tokens: Option<u32>,
    /// Last response_id for multi-turn chaining.
    last_response_id: std::sync::Mutex<Option<String>>,
    /// WebSocket session (when the oxide-ws feature is enabled and connected).
    #[cfg(feature = "oxide-ws")]
    ws: tokio::sync::Mutex<Option<openai_oxide::websocket::WsSession>>,
    /// Lazy WS: true = connect on first request, false = HTTP only.
    #[cfg(feature = "oxide-ws")]
    ws_enabled: std::sync::atomic::AtomicBool,
}

impl OxideClient {
    /// Create from LlmConfig.
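    ///
    /// A minimal construction sketch (module paths assume a `sgr_agent` crate
    /// root; key and model are placeholders):
    ///
    /// ```no_run
    /// use sgr_agent::oxide_client::OxideClient;
    /// use sgr_agent::types::LlmConfig;
    ///
    /// let config = LlmConfig::with_key("sk-...", "gpt-5.4");
    /// let client = OxideClient::from_config(&config).expect("valid config");
    /// ```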
    pub fn from_config(config: &LlmConfig) -> Result<Self, SgrError> {
        let api_key = config
            .api_key
            .clone()
            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
            .unwrap_or_else(|| {
                if config.base_url.is_some() {
                    "dummy_key".into()
                } else {
                    "".into()
                }
            });

        if api_key.is_empty() {
            return Err(SgrError::Schema("No API key for oxide client".into()));
        }

        let mut client_config = ClientConfig::new(&api_key);
        if let Some(ref url) = config.base_url {
            client_config = client_config.base_url(url.clone());
        }

        Ok(Self {
            client: OpenAI::with_config(client_config),
            model: config.model.clone(),
            temperature: Some(config.temp),
            max_tokens: config.max_tokens,
            last_response_id: std::sync::Mutex::new(None),
            #[cfg(feature = "oxide-ws")]
            ws: tokio::sync::Mutex::new(None),
            #[cfg(feature = "oxide-ws")]
            ws_enabled: std::sync::atomic::AtomicBool::new(false),
        })
    }

    /// Enable WebSocket mode — lazy connect on first request.
    ///
    /// Does NOT open a connection immediately. The WS connection is established
    /// on the first `send_request_auto()` call, eliminating idle timeout issues.
    /// Falls back to HTTP automatically if WS fails.
    ///
    /// Requires the `oxide-ws` feature.
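    ///
    /// A usage sketch (marked `ignore` because the method is gated behind the
    /// `oxide-ws` feature; module paths assume a `sgr_agent` crate root):
    ///
    /// ```ignore
    /// # async fn demo(client: &sgr_agent::oxide_client::OxideClient) -> Result<(), sgr_agent::types::SgrError> {
    /// // Only flips the flag; the WS connection dials on the next request
    /// // and silently falls back to HTTP if the upgrade fails.
    /// client.connect_ws().await?;
    /// # Ok(())
    /// # }
    /// ```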
    #[cfg(feature = "oxide-ws")]
    pub async fn connect_ws(&self) -> Result<(), SgrError> {
        self.ws_enabled
            .store(true, std::sync::atomic::Ordering::Relaxed);
        tracing::info!(model = %self.model, "oxide WebSocket enabled (lazy connect)");
        Ok(())
    }

    /// Send request — lazy WS connect + send, falls back to HTTP on any WS error.
    async fn send_request_auto(
        &self,
        request: ResponseCreateRequest,
    ) -> Result<Response, SgrError> {
        #[cfg(feature = "oxide-ws")]
        if self.ws_enabled.load(std::sync::atomic::Ordering::Relaxed) {
            let mut ws_guard = self.ws.lock().await;

            // Lazy connect
            if ws_guard.is_none() {
                match self.client.ws_session().await {
                    Ok(session) => {
                        tracing::info!(model = %self.model, "oxide WS connected (lazy)");
                        *ws_guard = Some(session);
                    }
                    Err(e) => {
                        tracing::warn!("oxide WS connect failed, using HTTP: {e}");
                        self.ws_enabled
                            .store(false, std::sync::atomic::Ordering::Relaxed);
                    }
                }
            }

            if let Some(ref mut session) = *ws_guard {
                match session.send(request.clone()).await {
                    Ok(response) => return Ok(response),
                    Err(e) => {
                        tracing::warn!("oxide WS send failed, falling back to HTTP: {e}");
                        *ws_guard = None;
                    }
                }
            }
        }

        // HTTP fallback
        self.client
            .responses()
            .create(request)
            .await
            .map_err(|e| SgrError::Api {
                status: 0,
                body: e.to_string(),
            })
    }

    /// Build a request with mixed input: regular messages + function_call_output items.
    /// Required when chaining with previous_response_id after a function-call response.
    fn build_request_with_tool_outputs(&self, messages: &[Message]) -> ResponseCreateRequest {
        use openai_oxide::types::responses::ResponseInput;

        let mut items: Vec<Value> = Vec::new();

        for msg in messages {
            match msg.role {
                Role::Tool => {
                    if let Some(ref call_id) = msg.tool_call_id {
                        // Responses API function_call_output item
                        items.push(serde_json::json!({
                            "type": "function_call_output",
                            "call_id": call_id,
                            "output": msg.content
                        }));
                    }
                }
                Role::System => {
                    items.push(serde_json::json!({
                        "type": "message",
                        "role": "system",
                        "content": msg.content
                    }));
                }
                Role::User => {
                    items.push(serde_json::json!({
                        "type": "message",
                        "role": "user",
                        "content": msg.content
                    }));
                }
                Role::Assistant => {
                    items.push(serde_json::json!({
                        "type": "message",
                        "role": "assistant",
                        "content": msg.content
                    }));
                }
            }
        }

        let mut req = ResponseCreateRequest::new(&self.model);
        if !items.is_empty() {
            req.input = Some(ResponseInput::Items(items));
        }

        // Temperature: send as usual; the openai-oxide WS layer strips decimal
        // values itself (OpenAI WS bug: https://community.openai.com/t/1375536).
        // The default (1.0) is skipped to reduce payload.
        if let Some(temp) = self.temperature
            && (temp - 1.0).abs() > f64::EPSILON
        {
            req = req.temperature(temp);
        }
        if let Some(max) = self.max_tokens {
            req = req.max_output_tokens(max as i64);
        }

        // Chain previous response if available
        if let Some(prev_id) = self.last_response_id.lock().ok().and_then(|g| g.clone()) {
            req = req.previous_response_id(prev_id);
        }

        req
    }

    /// Build a ResponseCreateRequest from messages + optional schema.
    fn build_request(&self, messages: &[Message], schema: Option<&Value>) -> ResponseCreateRequest {
        let mut input_items = Vec::new();

        for msg in messages {
            match msg.role {
                Role::System => {
                    input_items.push(ResponseInputItem {
                        role: openai_oxide::types::common::Role::System,
                        content: Value::String(msg.content.clone()),
                    });
                }
                Role::User => {
                    input_items.push(ResponseInputItem {
                        role: openai_oxide::types::common::Role::User,
                        content: Value::String(msg.content.clone()),
                    });
                }
                Role::Assistant => {
                    input_items.push(ResponseInputItem {
                        role: openai_oxide::types::common::Role::Assistant,
                        content: Value::String(msg.content.clone()),
                    });
                }
                Role::Tool => {
                    let tool_result = if let Some(ref id) = msg.tool_call_id {
                        format!("[Tool result for {}]: {}", id, msg.content)
                    } else {
                        msg.content.clone()
                    };
                    input_items.push(ResponseInputItem {
                        role: openai_oxide::types::common::Role::User,
                        content: Value::String(tool_result),
                    });
                }
            }
        }

        let mut req = ResponseCreateRequest::new(&self.model);

        // Set input — prefer simple text when there is a single user message (fewer tokens)
        if input_items.len() == 1 && input_items[0].role == openai_oxide::types::common::Role::User
        {
            if let Some(text) = input_items[0].content.as_str() {
                req = req.input(text);
            } else {
                req.input = Some(ResponseInput::Messages(input_items));
            }
        } else if !input_items.is_empty() {
            req.input = Some(ResponseInput::Messages(input_items));
        }

        // Temperature — skip the default to reduce payload
        if let Some(temp) = self.temperature
            && (temp - 1.0).abs() > f64::EPSILON
        {
            req = req.temperature(temp);
        }

        // Max tokens
        if let Some(max) = self.max_tokens {
            req = req.max_output_tokens(max as i64);
        }

        // Structured output via json_schema
        if let Some(schema_val) = schema {
            req = req.text(ResponseTextConfig {
                format: Some(ResponseTextFormat::JsonSchema {
                    name: "sgr_response".into(),
                    description: None,
                    schema: Some(schema_val.clone()),
                    strict: Some(true),
                }),
                verbosity: None,
            });
        }

        // Chain previous response if available
        if let Some(prev_id) = self.last_response_id.lock().ok().and_then(|g| g.clone()) {
            req = req.previous_response_id(prev_id);
        }

        req
    }

    /// Save the response_id for multi-turn chaining.
    fn save_response_id(&self, id: &str) {
        if let Ok(mut guard) = self.last_response_id.lock() {
            *guard = Some(id.to_string());
        }
    }

    /// Set the response_id externally (for stateful session coordination with the coach).
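    ///
    /// E.g., to resume a server-side conversation (the id is a placeholder):
    ///
    /// ```no_run
    /// # fn demo(client: &sgr_agent::oxide_client::OxideClient) {
    /// client.set_response_id(Some("resp_abc123"));
    /// assert_eq!(client.response_id().as_deref(), Some("resp_abc123"));
    /// # }
    /// ```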
    pub fn set_response_id(&self, id: Option<&str>) {
        if let Ok(mut guard) = self.last_response_id.lock() {
            *guard = id.map(String::from);
        }
    }

    /// Get the current response_id.
    pub fn response_id(&self) -> Option<String> {
        self.last_response_id.lock().ok().and_then(|g| g.clone())
    }

    /// Function calling with an explicit previous_response_id.
    /// Returns tool calls + the new response_id for chaining.
    ///
    /// Always sets `store(true)` so responses can be referenced by subsequent calls.
    /// When `previous_response_id` is provided, only delta messages need to be sent
    /// (the server has the full history from the previous stored response).
    ///
    /// Tool messages (role=Tool with tool_call_id) are converted to Responses API
    /// `function_call_output` items — required for chaining with previous_response_id.
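    ///
    /// A chaining sketch (message text and the tool list are placeholders
    /// supplied by the caller; module paths assume a `sgr_agent` crate root):
    ///
    /// ```no_run
    /// # async fn demo(
    /// #     client: &sgr_agent::oxide_client::OxideClient,
    /// #     tools: &[sgr_agent::tool::ToolDef],
    /// # ) -> Result<(), sgr_agent::types::SgrError> {
    /// use sgr_agent::types::Message;
    ///
    /// // First turn: full message list, no previous id.
    /// let (_calls, id) = client
    ///     .tools_call_stateful(&[Message::user("What's the weather?")], tools, None)
    ///     .await?;
    /// // Next turn: send only the delta, chained to the stored response.
    /// let (_calls, _id) = client
    ///     .tools_call_stateful(&[Message::user("And tomorrow?")], tools, id.as_deref())
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```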
    pub async fn tools_call_stateful(
        &self,
        messages: &[Message],
        tools: &[ToolDef],
        previous_response_id: Option<&str>,
    ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
        // Set the external response_id for chaining
        if let Some(pid) = previous_response_id {
            self.set_response_id(Some(pid));
        }

        // Always use the Items format (with "type":"message" on each item).
        // The HTTP API also accepts the Messages format (without the type field),
        // but the WS API requires the explicit type. Using Items consistently
        // keeps both HTTP and WS working.
        let mut req = self.build_request_with_tool_outputs(messages);
        // Always store so the next call can chain via previous_response_id
        req = req.store(true);

        // Convert ToolDefs to ResponseTools with strict mode.
        // strict: true guarantees the LLM output matches the schema exactly (no parse errors).
        // oxide's ensure_strict() handles: additionalProperties, all-required,
        // nullable→anyOf, allOf inlining, oneOf→anyOf.
        let response_tools: Vec<ResponseTool> = tools
            .iter()
            .map(|t| {
                let mut params = t.parameters.clone();
                openai_oxide::parsing::ensure_strict(&mut params);
                ResponseTool::Function {
                    name: t.name.clone(),
                    description: if t.description.is_empty() {
                        None
                    } else {
                        Some(t.description.clone())
                    },
                    parameters: Some(params),
                    strict: Some(true),
                }
            })
            .collect();
        req = req.tools(response_tools);

        let response = self.send_request_auto(req).await?;

        let response_id = response.id.clone();
        self.save_response_id(&response_id);
        record_otel_usage(&response, &self.model);

        let input_tokens = response
            .usage
            .as_ref()
            .and_then(|u| u.input_tokens)
            .unwrap_or(0);
        let cached_tokens = response
            .usage
            .as_ref()
            .and_then(|u| u.input_tokens_details.as_ref())
            .and_then(|d| d.cached_tokens)
            .unwrap_or(0);

        tracing::info!(
            model = %response.model,
            response_id = %response_id,
            input_tokens,
            cached_tokens,
            chained = previous_response_id.is_some(),
            "oxide.tools_call_stateful"
        );

        Ok((Self::extract_tool_calls(&response), Some(response_id)))
    }

    /// Extract tool calls from Responses API output items.
    fn extract_tool_calls(response: &Response) -> Vec<ToolCall> {
        response
            .function_calls()
            .into_iter()
            .map(|fc| ToolCall {
                id: fc.call_id,
                name: fc.name,
                arguments: fc.arguments,
            })
            .collect()
    }
}

#[async_trait::async_trait]
impl LlmClient for OxideClient {
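    /// A structured-output sketch for `structured_call` (the schema is
    /// illustrative; module paths assume a `sgr_agent` crate root):
    ///
    /// ```no_run
    /// # async fn demo(client: &sgr_agent::oxide_client::OxideClient) -> Result<(), sgr_agent::types::SgrError> {
    /// use sgr_agent::client::LlmClient;
    /// use sgr_agent::types::Message;
    ///
    /// let schema = serde_json::json!({
    ///     "type": "object",
    ///     "properties": { "answer": { "type": "string" } },
    ///     "required": ["answer"]
    /// });
    /// let (parsed, _tool_calls, _raw) = client
    ///     .structured_call(&[Message::user("Hi")], &schema)
    ///     .await?;
    /// if let Some(json) = parsed {
    ///     println!("answer: {json}");
    /// }
    /// # Ok(())
    /// # }
    /// ```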
    async fn structured_call(
        &self,
        messages: &[Message],
        schema: &Value,
    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
        // Make the schema OpenAI-strict (oxide handles nullable, allOf, etc.)
        let mut strict_schema = schema.clone();
        openai_oxide::parsing::ensure_strict(&mut strict_schema);

        let req = self.build_request(messages, Some(&strict_schema));

        let span = tracing::info_span!(
            "oxide.responses.create",
            model = %self.model,
            method = "structured_call",
        );
        let _enter = span.enter();

        let response = self.send_request_auto(req).await?;

        self.save_response_id(&response.id);
        record_otel_usage(&response, &self.model);

        let raw_text = response.output_text();
        let tool_calls = Self::extract_tool_calls(&response);
        let parsed = serde_json::from_str::<Value>(&raw_text).ok();

        tracing::info!(
            model = %response.model,
            response_id = %response.id,
            input_tokens = response.usage.as_ref().and_then(|u| u.input_tokens).unwrap_or(0),
            output_tokens = response.usage.as_ref().and_then(|u| u.output_tokens).unwrap_or(0),
            "oxide.structured_call"
        );

        Ok((parsed, tool_calls, raw_text))
    }

    async fn tools_call(
        &self,
        messages: &[Message],
        tools: &[ToolDef],
    ) -> Result<Vec<ToolCall>, SgrError> {
        // Use the proper function_call_output format when messages contain tool
        // results, as required by the Responses API for previous_response_id chaining.
        let has_tool_messages = messages.iter().any(|m| m.role == Role::Tool);
        let mut req = if has_tool_messages {
            self.build_request_with_tool_outputs(messages)
        } else {
            self.build_request(messages, None)
        };

        // Convert ToolDefs to ResponseTools — no strict mode (faster server-side)
        let response_tools: Vec<ResponseTool> = tools
            .iter()
            .map(|t| ResponseTool::Function {
                name: t.name.clone(),
                description: if t.description.is_empty() {
                    None
                } else {
                    Some(t.description.clone())
                },
                parameters: Some(t.parameters.clone()),
                strict: None,
            })
            .collect();
        req = req.tools(response_tools);

        // Force the model to always call a tool — prevents text-only responses
        // that would lose answer content (tools_call only returns Vec<ToolCall>).
        req = req.tool_choice(openai_oxide::types::responses::ResponseToolChoice::Mode(
            "required".into(),
        ));

        let response = self.send_request_auto(req).await?;

        self.save_response_id(&response.id);
        record_otel_usage(&response, &self.model);

        tracing::info!(
            model = %response.model,
            response_id = %response.id,
            "oxide.tools_call"
        );

        let calls = Self::extract_tool_calls(&response);
        Ok(calls)
    }

    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
        let req = self.build_request(messages, None);

        let response = self.send_request_auto(req).await?;

        self.save_response_id(&response.id);
        record_otel_usage(&response, &self.model);

        let text = response.output_text();
        if text.is_empty() {
            return Err(SgrError::EmptyResponse);
        }

        tracing::info!(
            model = %response.model,
            response_id = %response.id,
            input_tokens = response.usage.as_ref().and_then(|u| u.input_tokens).unwrap_or(0),
            output_tokens = response.usage.as_ref().and_then(|u| u.output_tokens).unwrap_or(0),
            "oxide.complete"
        );

        Ok(text)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn oxide_client_from_config() {
        // Just test that construction doesn't panic
        let config = LlmConfig::with_key("sk-test", "gpt-5.4");
        let client = OxideClient::from_config(&config).unwrap();
        assert_eq!(client.model, "gpt-5.4");
    }

    #[test]
    fn build_request_simple() {
        let config = LlmConfig::with_key("sk-test", "gpt-5.4").temperature(0.5);
        let client = OxideClient::from_config(&config).unwrap();
        let messages = vec![Message::system("Be helpful."), Message::user("Hello")];
        let req = client.build_request(&messages, None);
        assert_eq!(req.model, "gpt-5.4");
        // The system prompt goes as an input message (not instructions) for fewer tokens
        assert!(req.instructions.is_none());
        assert!(req.input.is_some()); // system + user as messages
        assert_eq!(req.temperature, Some(0.5));
    }

    #[test]
    fn build_request_with_schema() {
        let config = LlmConfig::with_key("sk-test", "gpt-5.4");
        let client = OxideClient::from_config(&config).unwrap();
        let schema = serde_json::json!({
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"]
        });
        let req = client.build_request(&[Message::user("Hi")], Some(&schema));
        assert!(req.text.is_some());
    }
}
627}