
cortex_llm/openai_compat.rs

//! HTTP adapter that posts to any OpenAI-compatible `/v1/chat/completions`
//! endpoint.
//!
//! [`OpenAiCompatAdapter`] implements [`LlmAdapter`] and works with any server
//! that speaks the OpenAI chat-completions wire format, including LM Studio
//! (default `http://localhost:1234`), LocalAI, Ollama in OpenAI-proxy mode
//! (`http://localhost:11434`), vLLM, and the hosted OpenAI API itself.
//!
//! ## Key differences from [`crate::ollama_http::OllamaHttpAdapter`]
//!
//! - No loopback enforcement — `base_url` may be any valid HTTP(S) URL.
//! - No digest-pin requirement — local model names are arbitrary strings;
//!   any non-empty model identifier is accepted.
//! - `api_key` is optional; when absent or empty, the `Authorization` header
//!   is omitted entirely (appropriate for local servers).
//!
//! ## Runtime ceiling
//!
//! Determined at construction from `base_url`:
//!
//! - Loopback host (`localhost`, `127.x.x.x`, `::1`) → [`RuntimeCeiling::LocalUnsigned`].
//! - Any other host → [`RuntimeCeiling::RemoteUnsigned`].
//!
//! ## SSE streaming
//!
//! [`OpenAiCompatAdapter::stream_boxed`] overrides the trait default and reads
//! the response body as Server-Sent Events. Each `data:` line is parsed as an
//! OpenAI streaming chunk; `data: [DONE]` terminates the stream.
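//!
//! ## Example
//!
//! A minimal construction sketch. The crate path, model name, and port are
//! illustrative assumptions rather than fixed values:
//!
//! ```ignore
//! use cortex_llm::openai_compat::{OpenAiCompatAdapter, RuntimeCeiling};
//!
//! let adapter = OpenAiCompatAdapter::new(
//!     "http://localhost:1234",   // LM Studio default port
//!     "llama-3.1-8b-instruct",   // any non-empty model id the server recognises
//!     None,                      // no API key -> no Authorization header
//!     30_000,                    // 30 s per-call timeout
//!     None,                      // default sensitivity gate (Medium)
//! )
//! .expect("valid adapter configuration");
//!
//! assert_eq!(adapter.runtime_ceiling(), RuntimeCeiling::LocalUnsigned);
//! ```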

use std::net::IpAddr;
use std::time::Duration;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::adapter::{
    blake3_hex, BoxStream, LlmAdapter, LlmError, LlmRequest, LlmResponse, LlmRole, StreamChunk,
    TokenUsage,
};
use crate::sensitivity::{check_remote_prompt_sensitivity, MaxSensitivity};

// ---------------------------------------------------------------------------
// Runtime ceiling
// ---------------------------------------------------------------------------

/// Runtime ceiling derived from the `base_url` at adapter construction time.
///
/// Used by `cortex run` to gate persistence policy decisions (ADR 0037
/// weakest-link): a loopback endpoint carries `LocalUnsigned`; any other
/// endpoint carries `RemoteUnsigned`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuntimeCeiling {
    /// Endpoint host is a loopback address (`localhost`, `127.x.x.x`, `::1`).
    LocalUnsigned,
    /// Endpoint host is non-loopback.
    RemoteUnsigned,
}

// ---------------------------------------------------------------------------
// Adapter struct
// ---------------------------------------------------------------------------

/// HTTP adapter for any OpenAI-compatible `/v1/chat/completions` server.
///
/// Construct via [`OpenAiCompatAdapter::new`]. The adapter is `Send + Sync`
/// and may be held behind an `Arc<dyn LlmAdapter>`.
#[derive(Debug, Clone)]
pub struct OpenAiCompatAdapter {
    /// Base URL, e.g. `http://localhost:1234`. No trailing slash.
    base_url: String,
    /// Model identifier passed verbatim in the JSON body.
    model: String,
    /// Optional API key; `None` means no `Authorization` header is sent.
    api_key: Option<String>,
    /// Per-call HTTP timeout.
    timeout_ms: u64,
    /// ADR 0037 runtime ceiling computed from `base_url` at construction.
    ceiling: RuntimeCeiling,
    /// Maximum data-classification level permitted in remote prompts.
    /// Defaults to [`MaxSensitivity::Medium`] when constructed via [`Self::new`].
    max_sensitivity: MaxSensitivity,
}

impl OpenAiCompatAdapter {
    /// Construct an adapter.
    ///
    /// `api_key` may be `None` or an empty string (`Some("")`); in both cases
    /// no `Authorization` header is sent.
    ///
    /// `max_sensitivity` controls the data-classification gate before remote
    /// dispatch. Pass `None` to use the default of [`MaxSensitivity::Medium`],
    /// which blocks high-sensitivity memories from being sent to the endpoint.
    ///
    /// A warning is printed to stderr when `base_url` resolves to a non-loopback
    /// host, because all prompt content will be sent to that remote server.
    ///
    /// Returns [`LlmError::InvalidRequest`] when:
    /// - `base_url` does not start with `http://` or `https://`.
    /// - `base_url` contains no host.
    /// - `model` is empty.
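    ///
    /// # Example
    ///
    /// A sketch against a remote endpoint; the host and model name are
    /// placeholder assumptions (constructing this also prints the stderr
    /// warning described above):
    ///
    /// ```ignore
    /// // An empty key behaves like `None`: no `Authorization` header is sent.
    /// let adapter = OpenAiCompatAdapter::new(
    ///     "https://api.example.com",
    ///     "gpt-4o-mini",
    ///     Some(String::new()),
    ///     15_000,
    ///     None,
    /// )
    /// .expect("valid configuration");
    ///
    /// assert_eq!(adapter.runtime_ceiling(), RuntimeCeiling::RemoteUnsigned);
    /// ```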
    pub fn new(
        base_url: impl Into<String>,
        model: impl Into<String>,
        api_key: Option<String>,
        timeout_ms: u64,
        max_sensitivity: Option<MaxSensitivity>,
    ) -> Result<Self, LlmError> {
        let base_url = base_url.into();
        let model = model.into();

        if model.is_empty() {
            return Err(LlmError::InvalidRequest(
                "openai-compat: model must not be empty".to_string(),
            ));
        }

        let ceiling = ceiling_for_url(&base_url)?;

        // Warn when the endpoint is non-loopback: all prompt content is sent remotely.
        if ceiling == RuntimeCeiling::RemoteUnsigned {
            eprintln!(
                "cortex: openai-compat: WARNING: endpoint {} is not loopback-only. \
                 All prompt content will be sent to this remote server.",
                base_url
            );
        }

        // Normalise: treat empty string as absent key.
        let api_key = api_key.filter(|k| !k.is_empty());

        Ok(Self {
            base_url,
            model,
            api_key,
            timeout_ms,
            ceiling,
            max_sensitivity: max_sensitivity.unwrap_or(MaxSensitivity::Medium),
        })
    }

    /// Return the runtime ceiling determined at construction from `base_url`.
    #[must_use]
    pub fn runtime_ceiling(&self) -> RuntimeCeiling {
        self.ceiling
    }
}

// ---------------------------------------------------------------------------
// URL / ceiling helper
// ---------------------------------------------------------------------------

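/// Classify `base_url` into a [`RuntimeCeiling`].
///
/// Illustrative mappings (a sketch, not an exhaustive list):
///
/// - `http://localhost:1234`      → [`RuntimeCeiling::LocalUnsigned`]
/// - `http://127.0.0.1:8080/v1`   → [`RuntimeCeiling::LocalUnsigned`]
/// - `https://api.openai.com/v1`  → [`RuntimeCeiling::RemoteUnsigned`]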
fn ceiling_for_url(base_url: &str) -> Result<RuntimeCeiling, LlmError> {
    let rest = if let Some(r) = base_url.strip_prefix("http://") {
        r
    } else if let Some(r) = base_url.strip_prefix("https://") {
        r
    } else {
        return Err(LlmError::InvalidRequest(format!(
            "openai-compat: base_url must start with http:// or https://: {base_url}"
        )));
    };

    let host = extract_host(rest).ok_or_else(|| {
        LlmError::InvalidRequest(format!(
            "openai-compat: base_url must contain a host: {base_url}"
        ))
    })?;

    if is_loopback_host(host) {
        Ok(RuntimeCeiling::LocalUnsigned)
    } else {
        Ok(RuntimeCeiling::RemoteUnsigned)
    }
}

fn extract_host(rest: &str) -> Option<&str> {
    // Take the authority component (everything before the first `/`, `?`, or `#`),
    // then drop any userinfo prefix (e.g. `user:pass@`) so only `host[:port]` remains.
    let authority = rest.split(['/', '?', '#']).next().unwrap_or_default();
    let authority = match authority.rsplit_once('@') {
        Some((_userinfo, host_port)) => host_port,
        None => authority,
    };
    if authority.is_empty() {
        return None;
    }

    // IPv6 literal `[::1]:port`
    if let Some(after_open) = authority.strip_prefix('[') {
        let (host, suffix) = after_open.split_once(']')?;
        if suffix.is_empty() || suffix.starts_with(':') {
            return Some(host);
        }
        return None;
    }

    // Strip optional port.
    let host = authority.split(':').next().unwrap_or_default();
    if host.is_empty() {
        None
    } else {
        Some(host)
    }
}

fn is_loopback_host(host: &str) -> bool {
    if host.eq_ignore_ascii_case("localhost") {
        return true;
    }
    host.parse::<IpAddr>().is_ok_and(|ip| ip.is_loopback())
}

// ---------------------------------------------------------------------------
// Wire types — non-streaming
// ---------------------------------------------------------------------------

/// Outgoing body for `POST /v1/chat/completions` (non-streaming).
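///
/// Serializes to a body of this shape (values are illustrative placeholders):
/// `{"model": "...", "messages": [{"role": "user", "content": "..."}], "stream": false, "max_tokens": 256}`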
#[derive(Debug, Serialize)]
struct ChatCompletionRequest<'a> {
    model: &'a str,
    messages: Vec<OpenAiMessage<'a>>,
    stream: bool,
    max_tokens: u32,
}

/// One message in the OpenAI chat format.
#[derive(Debug, Serialize)]
struct OpenAiMessage<'a> {
    role: &'a str,
    content: &'a str,
}

/// Top-level `/v1/chat/completions` response envelope.
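///
/// Expected shape (abridged; values are illustrative):
/// `{"choices": [{"message": {"content": "..."}}], "usage": {"prompt_tokens": 12, "completion_tokens": 34}}`
///
/// Missing `choices` or `usage` fields are tolerated via `#[serde(default)]`.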
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    #[serde(default)]
    choices: Vec<Choice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

/// One element of the `choices` array.
#[derive(Debug, Deserialize)]
struct Choice {
    #[serde(default)]
    message: ChoiceMessage,
}

/// The `message` field inside a non-streaming choice.
#[derive(Debug, Default, Deserialize)]
struct ChoiceMessage {
    #[serde(default)]
    content: String,
}

/// Token usage reported by the provider (optional).
#[derive(Debug, Deserialize)]
struct OpenAiUsage {
    #[serde(default)]
    prompt_tokens: u32,
    #[serde(default)]
    completion_tokens: u32,
}

// ---------------------------------------------------------------------------
// Wire types — SSE streaming
// ---------------------------------------------------------------------------

/// One `data:` line from an OpenAI SSE stream.
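///
/// Example payload (abridged, illustrative):
/// `{"choices": [{"delta": {"content": "Hel"}, "finish_reason": null}]}`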
#[derive(Debug, Deserialize)]
struct StreamChunkEnvelope {
    #[serde(default)]
    choices: Vec<StreamChoice>,
}

/// One element of `choices` in a streaming delta.
#[derive(Debug, Default, Deserialize)]
struct StreamChoice {
    #[serde(default)]
    delta: StreamDelta,
    finish_reason: Option<String>,
}

/// The `delta` field inside a streaming choice.
#[derive(Debug, Default, Deserialize)]
struct StreamDelta {
    #[serde(default)]
    content: String,
}

// ---------------------------------------------------------------------------
// LlmAdapter implementation
// ---------------------------------------------------------------------------

#[async_trait]
impl LlmAdapter for OpenAiCompatAdapter {
    fn adapter_id(&self) -> &'static str {
        "openai-compat"
    }

    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
        // Sensitivity gate: reject prompts containing high-sensitivity markers
        // before any bytes leave the machine.
        let prompt_text: String = std::iter::once(req.system.as_str())
            .chain(req.messages.iter().map(|m| m.content.as_str()))
            .collect::<Vec<_>>()
            .join("\n");
        check_remote_prompt_sensitivity(&prompt_text, self.max_sensitivity)?;

        let base_url = self.base_url.clone();
        let model = self.model.clone();
        let api_key = self.api_key.clone();
        let timeout_ms = self.timeout_ms;

        let result = tokio::task::spawn_blocking(move || {
            call_openai_compat(&base_url, &model, api_key.as_deref(), &req, timeout_ms)
        })
        .await
        .map_err(|e| LlmError::Transport(format!("spawn_blocking join error: {e}")))?;

        result
    }

    /// Override the trait default with OpenAI SSE streaming via
    /// `POST /v1/chat/completions` and `"stream": true`.
    ///
    /// Parses `data:` lines, extracts `choices[0].delta.content`, and converts
    /// the terminal `data: [DONE]` sentinel into a final chunk with
    /// `finish_reason: "stop"`.
    fn stream_boxed(&self, req: LlmRequest) -> BoxStream<'_> {
        stream_openai_compat_sse(
            self.base_url.clone(),
            self.model.clone(),
            self.api_key.clone(),
            req,
        )
    }
}

// ---------------------------------------------------------------------------
// Synchronous HTTP call (non-streaming)
// ---------------------------------------------------------------------------

fn call_openai_compat(
    base_url: &str,
    model: &str,
    api_key: Option<&str>,
    req: &LlmRequest,
    timeout_ms: u64,
) -> Result<LlmResponse, LlmError> {
    let url = format!("{base_url}/v1/chat/completions");

    let messages: Vec<OpenAiMessage<'_>> = req
        .messages
        .iter()
        .map(|m| OpenAiMessage {
            role: role_to_str(m.role),
            content: &m.content,
        })
        .collect();

    let body = ChatCompletionRequest {
        model,
        messages,
        stream: false,
        max_tokens: req.max_tokens,
    };

    let body_value = serde_json::to_value(&body)
        .map_err(|e| LlmError::Transport(format!("request serialization failed: {e}")))?;

    let timeout = Duration::from_millis(timeout_ms);
    let agent = ureq::AgentBuilder::new().timeout(timeout).build();

    let mut request = agent.post(&url).set("content-type", "application/json");
    if let Some(key) = api_key {
        request = request.set("authorization", &format!("Bearer {key}"));
    }

    let raw_response = request
        .send_json(body_value)
        .map_err(|err| map_ureq_error(err, timeout_ms))?;

    let status = raw_response.status();
    if status != 200 {
        return Err(LlmError::Upstream(format!("HTTP {status}")));
    }

    let response_text = raw_response
        .into_string()
        .map_err(|e| LlmError::Transport(format!("reading response body: {e}")))?;

    let parsed: ChatCompletionResponse = serde_json::from_str(&response_text)
        .map_err(|e| LlmError::Parse(format!("openai-compat response parse: {e}")))?;

    let text = parsed
        .choices
        .into_iter()
        .next()
        .map(|c| c.message.content)
        .ok_or_else(|| {
            LlmError::Parse("openai-compat response contained no choices".to_string())
        })?;

    let raw_hash = blake3_hex(response_text.as_bytes());
    let usage = parsed.usage.map(|u| TokenUsage {
        prompt_tokens: u.prompt_tokens,
        completion_tokens: u.completion_tokens,
    });

    Ok(LlmResponse {
        text,
        parsed_json: None,
        model: model.to_string(),
        usage,
        raw_hash,
    })
}

// ---------------------------------------------------------------------------
// SSE streaming implementation
// ---------------------------------------------------------------------------

/// Build a [`BoxStream`] that drives OpenAI-compatible SSE streaming.
///
/// Extracted as a free function so the `async_stream::stream!` macro is not
/// nested inside an `impl` block, which can confuse lifetime inference.
fn stream_openai_compat_sse(
    base_url: String,
    model: String,
    api_key: Option<String>,
    req: LlmRequest,
) -> BoxStream<'static> {
    Box::pin(async_stream::stream! {
        let timeout_ms = req.timeout_ms;
        let result = tokio::task::spawn_blocking(move || {
            call_openai_compat_streaming(&base_url, &model, api_key.as_deref(), &req, timeout_ms)
        })
        .await;

        match result {
            Ok(chunks) => {
                for chunk in chunks {
                    yield chunk;
                }
            }
            Err(e) => yield Err(LlmError::Transport(format!("spawn_blocking join error: {e}"))),
        }
    })
}

/// Synchronous OpenAI-compatible SSE streaming call, executed inside
/// `spawn_blocking`.
///
/// Posts to `/v1/chat/completions` with `stream: true`, then reads the
/// response body line by line. SSE protocol:
/// - Empty lines are event separators — skip them.
/// - Lines beginning with `event:` are event-type hints — skip them.
/// - `data: [DONE]` is the terminal sentinel — stop processing.
/// - Lines beginning with `data:` carry the JSON delta payload.
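///
/// A typical response body therefore looks like (abridged, illustrative):
///
/// ```text
/// data: {"choices":[{"delta":{"content":"Hel"},"finish_reason":null}]}
///
/// data: {"choices":[{"delta":{"content":"lo"},"finish_reason":null}]}
///
/// data: [DONE]
/// ```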
fn call_openai_compat_streaming(
    base_url: &str,
    model: &str,
    api_key: Option<&str>,
    req: &LlmRequest,
    timeout_ms: u64,
) -> Vec<Result<StreamChunk, LlmError>> {
    let url = format!("{base_url}/v1/chat/completions");

    let messages: Vec<OpenAiMessage<'_>> = req
        .messages
        .iter()
        .map(|m| OpenAiMessage {
            role: role_to_str(m.role),
            content: &m.content,
        })
        .collect();

    let body = ChatCompletionRequest {
        model,
        messages,
        stream: true,
        max_tokens: req.max_tokens,
    };

    let body_value = match serde_json::to_value(&body) {
        Ok(v) => v,
        Err(e) => {
            return vec![Err(LlmError::Transport(format!(
                "request serialization failed: {e}"
            )))]
        }
    };

    let timeout = Duration::from_millis(timeout_ms);
    let agent = ureq::AgentBuilder::new().timeout(timeout).build();

    let mut request = agent.post(&url).set("content-type", "application/json");
    if let Some(key) = api_key {
        request = request.set("authorization", &format!("Bearer {key}"));
    }

    let raw_response = match request.send_json(body_value) {
        Ok(r) => r,
        Err(err) => return vec![Err(map_ureq_error(err, timeout_ms))],
    };

    let status = raw_response.status();
    if status != 200 {
        return vec![Err(LlmError::Upstream(format!("HTTP {status}")))];
    }

    let body_text = match raw_response.into_string() {
        Ok(s) => s,
        Err(e) => {
            return vec![Err(LlmError::Transport(format!(
                "reading streaming response body: {e}"
            )))]
        }
    };

    let mut chunks = Vec::new();

    for line in body_text.lines() {
        if line.is_empty() || line.starts_with("event:") {
            continue;
        }

        let data = match line.strip_prefix("data:") {
            Some(rest) => rest.trim(),
            None => continue,
        };

        // Terminal sentinel — no further lines need processing.
        if data == "[DONE]" {
            chunks.push(Ok(StreamChunk {
                delta: String::new(),
                finish_reason: Some("stop".into()),
            }));
            return chunks;
        }

        let envelope: StreamChunkEnvelope = match serde_json::from_str(data) {
            Ok(v) => v,
            Err(e) => {
                chunks.push(Err(LlmError::Parse(format!(
                    "openai-compat SSE data parse: {e}: {data}"
                ))));
                continue;
            }
        };

        let choice = match envelope.choices.into_iter().next() {
            Some(c) => c,
            None => continue,
        };

        let finish_reason = choice.finish_reason;
        let delta_text = choice.delta.content;

        chunks.push(Ok(StreamChunk {
            delta: delta_text,
            finish_reason,
        }));
    }

    chunks
}

// ---------------------------------------------------------------------------
// ureq error mapping
// ---------------------------------------------------------------------------

fn map_ureq_error(err: ureq::Error, timeout_ms: u64) -> LlmError {
    match err {
        ureq::Error::Transport(t) => {
            let msg = t.to_string();
            if is_timeout_message(&msg) {
                LlmError::Timeout { timeout_ms }
            } else {
                LlmError::Transport(msg)
            }
        }
        ureq::Error::Status(code, _) => LlmError::Upstream(format!("HTTP {code}")),
    }
}

fn is_timeout_message(msg: &str) -> bool {
    let lower = msg.to_ascii_lowercase();
    lower.contains("timed out") || lower.contains("deadline exceeded") || lower.contains("timeout")
}

// ---------------------------------------------------------------------------
// Role serialization helper
// ---------------------------------------------------------------------------

/// Return the lowercase role string used by the OpenAI chat-completions API.
fn role_to_str(role: LlmRole) -> &'static str {
    match role {
        LlmRole::User => "user",
        LlmRole::Assistant => "assistant",
        LlmRole::Tool => "tool",
    }
}