//! Adapter for OpenAI-compatible `/v1/chat/completions` endpoints.
//!
//! The base URL is classified at construction as loopback-only or remote
//! ([`RuntimeCeiling`]), with a loud warning when prompts would leave the
//! machine. Completions run as blocking `ureq` calls on tokio's blocking
//! pool, in both one-shot and SSE-streaming forms.

use std::net::IpAddr;
use std::time::Duration;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::adapter::{
    blake3_hex, BoxStream, LlmAdapter, LlmError, LlmRequest, LlmResponse, LlmRole, StreamChunk,
    TokenUsage,
};
use crate::sensitivity::{check_remote_prompt_sensitivity, MaxSensitivity};

/// How far a configured endpoint may be trusted, derived from its base URL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuntimeCeiling {
    /// The endpoint resolves to a loopback host; prompts stay on-machine.
    LocalUnsigned,
    /// The endpoint is non-loopback; all prompt content leaves the machine.
    RemoteUnsigned,
}

/// Adapter for servers that expose the OpenAI `/v1/chat/completions` API.
#[derive(Debug, Clone)]
pub struct OpenAiCompatAdapter {
    /// Endpoint origin (scheme, host, optional port); the
    /// `/v1/chat/completions` path is appended per request.
    base_url: String,
    /// Model identifier sent in each request body.
    model: String,
    /// Optional bearer token; `None` when the server needs no auth.
    api_key: Option<String>,
    /// Per-request timeout in milliseconds.
    timeout_ms: u64,
    /// Trust classification computed from `base_url` at construction.
    ceiling: RuntimeCeiling,
    /// Sensitivity ceiling passed to `check_remote_prompt_sensitivity`
    /// for outgoing prompt content.
    max_sensitivity: MaxSensitivity,
}

impl OpenAiCompatAdapter {
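    /// Builds an adapter, rejecting empty model names and classifying the
    /// endpoint URL as loopback or remote.
    ///
    /// A minimal usage sketch; the local server address and model id are
    /// assumptions for illustration:
    ///
    /// ```ignore
    /// let adapter = OpenAiCompatAdapter::new(
    ///     "http://127.0.0.1:8080", // loopback => LocalUnsigned
    ///     "llama-3.1-8b-instruct", // any non-empty model id
    ///     None,                    // no API key for a local server
    ///     30_000,                  // request timeout in milliseconds
    ///     None,                    // defaults to MaxSensitivity::Medium
    /// )?;
    /// assert_eq!(adapter.runtime_ceiling(), RuntimeCeiling::LocalUnsigned);
    /// ```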
    pub fn new(
        base_url: impl Into<String>,
        model: impl Into<String>,
        api_key: Option<String>,
        timeout_ms: u64,
        max_sensitivity: Option<MaxSensitivity>,
    ) -> Result<Self, LlmError> {
        let base_url = base_url.into();
        let model = model.into();

        if model.is_empty() {
            return Err(LlmError::InvalidRequest(
                "openai-compat: model must not be empty".to_string(),
            ));
        }

        let ceiling = ceiling_for_url(&base_url)?;

        if ceiling == RuntimeCeiling::RemoteUnsigned {
            // Warn loudly at construction: every prompt sent through this
            // adapter will cross the network to this host.
            eprintln!(
                "cortex: openai-compat: WARNING: endpoint {} is not loopback-only. \
                 All prompt content will be sent to this remote server.",
                base_url
            );
        }

        // Treat an empty key string the same as no key at all.
        let api_key = api_key.filter(|k| !k.is_empty());

        Ok(Self {
            base_url,
            model,
            api_key,
            timeout_ms,
            ceiling,
            max_sensitivity: max_sensitivity.unwrap_or(MaxSensitivity::Medium),
        })
    }

    /// The trust classification computed for this adapter's endpoint.
    #[must_use]
    pub fn runtime_ceiling(&self) -> RuntimeCeiling {
        self.ceiling
    }
}
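/// Classifies a base URL by scheme and host.
///
/// A sketch of the expected behavior, following the implementation below
/// (`ignore`d doctest since these helpers are private):
///
/// ```ignore
/// assert_eq!(ceiling_for_url("http://127.0.0.1:8080").unwrap(), RuntimeCeiling::LocalUnsigned);
/// assert_eq!(ceiling_for_url("http://[::1]:8080").unwrap(), RuntimeCeiling::LocalUnsigned);
/// assert_eq!(ceiling_for_url("https://api.example.com").unwrap(), RuntimeCeiling::RemoteUnsigned);
/// assert!(ceiling_for_url("ftp://example.com").is_err()); // unsupported scheme
/// ```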
fn ceiling_for_url(base_url: &str) -> Result<RuntimeCeiling, LlmError> {
    let rest = if let Some(r) = base_url.strip_prefix("http://") {
        r
    } else if let Some(r) = base_url.strip_prefix("https://") {
        r
    } else {
        return Err(LlmError::InvalidRequest(format!(
            "openai-compat: base_url must start with http:// or https://: {base_url}"
        )));
    };

    let host = extract_host(rest).ok_or_else(|| {
        LlmError::InvalidRequest(format!(
            "openai-compat: base_url must contain a host: {base_url}"
        ))
    })?;

    if is_loopback_host(host) {
        Ok(RuntimeCeiling::LocalUnsigned)
    } else {
        Ok(RuntimeCeiling::RemoteUnsigned)
    }
}

/// Pulls the host out of the part of a URL after the scheme, handling
/// bracketed IPv6 literals and optional ports.
fn extract_host(rest: &str) -> Option<&str> {
    // The authority is everything before the first path, query, or fragment.
    let authority = rest.split(['/', '?', '#']).next().unwrap_or_default();
    if authority.is_empty() {
        return None;
    }

    // Bracketed IPv6 literal, e.g. "[::1]:8080".
    if let Some(after_open) = authority.strip_prefix('[') {
        let (host, suffix) = after_open.split_once(']')?;
        if suffix.is_empty() || suffix.starts_with(':') {
            return Some(host);
        }
        return None;
    }

    // Plain host, with an optional ":port" suffix.
    let host = authority.split(':').next().unwrap_or_default();
    if host.is_empty() {
        None
    } else {
        Some(host)
    }
}

/// True for "localhost" (case-insensitive) and any literal loopback IP.
fn is_loopback_host(host: &str) -> bool {
    if host.eq_ignore_ascii_case("localhost") {
        return true;
    }
    host.parse::<IpAddr>().is_ok_and(|ip| ip.is_loopback())
}
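/// Request body for `POST /v1/chat/completions`, limited to the fields this
/// adapter sends.
///
/// With `serde_json` this serializes to the standard OpenAI shape; the
/// values below are illustrative only:
///
/// ```text
/// {"model":"m","messages":[{"role":"user","content":"hi"}],"stream":false,"max_tokens":16}
/// ```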
#[derive(Debug, Serialize)]
struct ChatCompletionRequest<'a> {
    model: &'a str,
    messages: Vec<OpenAiMessage<'a>>,
    stream: bool,
    max_tokens: u32,
}

/// One chat message on the wire; borrows from the `LlmRequest`.
#[derive(Debug, Serialize)]
struct OpenAiMessage<'a> {
    role: &'a str,
    content: &'a str,
}

/// Non-streaming response body. Unknown fields are ignored and missing ones
/// default, so lenient servers still parse.
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    #[serde(default)]
    choices: Vec<Choice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

#[derive(Debug, Deserialize)]
struct Choice {
    #[serde(default)]
    message: ChoiceMessage,
}

#[derive(Debug, Default, Deserialize)]
struct ChoiceMessage {
    #[serde(default)]
    content: String,
}

/// Token accounting as reported by the server, when present.
#[derive(Debug, Deserialize)]
struct OpenAiUsage {
    #[serde(default)]
    prompt_tokens: u32,
    #[serde(default)]
    completion_tokens: u32,
}

/// One parsed SSE `data:` payload from a streaming response.
#[derive(Debug, Deserialize)]
struct StreamChunkEnvelope {
    #[serde(default)]
    choices: Vec<StreamChoice>,
}

#[derive(Debug, Default, Deserialize)]
struct StreamChoice {
    #[serde(default)]
    delta: StreamDelta,
    finish_reason: Option<String>,
}

#[derive(Debug, Default, Deserialize)]
struct StreamDelta {
    #[serde(default)]
    content: String,
}

#[async_trait]
impl LlmAdapter for OpenAiCompatAdapter {
    fn adapter_id(&self) -> &'static str {
        "openai-compat"
    }

    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
        // Flatten the system prompt and message contents into one string so
        // the sensitivity check sees everything that would go on the wire.
        let prompt_text: String = std::iter::once(req.system.as_str())
            .chain(req.messages.iter().map(|m| m.content.as_str()))
            .collect::<Vec<_>>()
            .join("\n");
        check_remote_prompt_sensitivity(&prompt_text, self.max_sensitivity)?;

        let base_url = self.base_url.clone();
        let model = self.model.clone();
        let api_key = self.api_key.clone();
        let timeout_ms = self.timeout_ms;

        // ureq is blocking, so hop onto the blocking thread pool rather than
        // stalling the async executor.
        let result = tokio::task::spawn_blocking(move || {
            call_openai_compat(&base_url, &model, api_key.as_deref(), &req, timeout_ms)
        })
        .await
        .map_err(|e| LlmError::Transport(format!("spawn_blocking join error: {e}")))?;

        result
    }

    fn stream_boxed(&self, req: LlmRequest) -> BoxStream<'_> {
        // Hand owned copies to the 'static stream; the helper runs the
        // blocking request on the blocking pool.
        stream_openai_compat_sse(
            self.base_url.clone(),
            self.model.clone(),
            self.api_key.clone(),
            req,
        )
    }
}

/// Blocking, non-streaming call to `{base_url}/v1/chat/completions`.
fn call_openai_compat(
    base_url: &str,
    model: &str,
    api_key: Option<&str>,
    req: &LlmRequest,
    timeout_ms: u64,
) -> Result<LlmResponse, LlmError> {
    let url = format!("{base_url}/v1/chat/completions");

    let messages: Vec<OpenAiMessage<'_>> = req
        .messages
        .iter()
        .map(|m| OpenAiMessage {
            role: role_to_str(m.role),
            content: &m.content,
        })
        .collect();

    let body = ChatCompletionRequest {
        model,
        messages,
        stream: false,
        max_tokens: req.max_tokens,
    };

    let body_value = serde_json::to_value(&body)
        .map_err(|e| LlmError::Transport(format!("request serialization failed: {e}")))?;

    let timeout = Duration::from_millis(timeout_ms);
    let agent = ureq::AgentBuilder::new().timeout(timeout).build();

    let mut request = agent.post(&url).set("content-type", "application/json");
    if let Some(key) = api_key {
        request = request.set("authorization", &format!("Bearer {key}"));
    }

    let raw_response = request
        .send_json(body_value)
        .map_err(|err| map_ureq_error(err, timeout_ms))?;

    let status = raw_response.status();
    if status != 200 {
        return Err(LlmError::Upstream(format!("HTTP {status}")));
    }

    // Keep the raw body so its hash can be recorded alongside the parsed
    // response.
    let response_text = raw_response
        .into_string()
        .map_err(|e| LlmError::Transport(format!("reading response body: {e}")))?;

    let parsed: ChatCompletionResponse = serde_json::from_str(&response_text)
        .map_err(|e| LlmError::Parse(format!("openai-compat response parse: {e}")))?;

    let text = parsed
        .choices
        .into_iter()
        .next()
        .map(|c| c.message.content)
        .ok_or_else(|| {
            LlmError::Parse("openai-compat response contained no choices".to_string())
        })?;

    let raw_hash = blake3_hex(response_text.as_bytes());
    let usage = parsed.usage.map(|u| TokenUsage {
        prompt_tokens: u.prompt_tokens,
        completion_tokens: u.completion_tokens,
    });

    Ok(LlmResponse {
        text,
        parsed_json: None,
        model: model.to_string(),
        usage,
        raw_hash,
    })
}

/// Streams a completion as a `BoxStream` of `StreamChunk`s.
///
/// The blocking HTTP call runs on the blocking pool; its collected chunks
/// are then replayed through an `async_stream` so callers get an ordinary
/// async stream.
fn stream_openai_compat_sse(
    base_url: String,
    model: String,
    api_key: Option<String>,
    req: LlmRequest,
) -> BoxStream<'static> {
    Box::pin(async_stream::stream! {
        // Note: the streaming path takes its timeout from the request itself.
        let timeout_ms = req.timeout_ms;
        let result = tokio::task::spawn_blocking(move || {
            call_openai_compat_streaming(&base_url, &model, api_key.as_deref(), &req, timeout_ms)
        })
        .await;

        match result {
            Ok(chunks) => {
                for chunk in chunks {
                    yield chunk;
                }
            }
            Err(e) => yield Err(LlmError::Transport(format!("spawn_blocking join error: {e}"))),
        }
    })
}

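/// Blocking streaming call: sends `stream: true`, reads the entire SSE body,
/// and parses it into an ordered list of chunk results.
///
/// A sketch of the frames the loop below expects, following the OpenAI SSE
/// convention (payload values are illustrative):
///
/// ```text
/// data: {"choices":[{"delta":{"content":"Hel"},"finish_reason":null}]}
///
/// data: {"choices":[{"delta":{"content":"lo"},"finish_reason":"stop"}]}
///
/// data: [DONE]
/// ```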
fn call_openai_compat_streaming(
    base_url: &str,
    model: &str,
    api_key: Option<&str>,
    req: &LlmRequest,
    timeout_ms: u64,
) -> Vec<Result<StreamChunk, LlmError>> {
    let url = format!("{base_url}/v1/chat/completions");

    let messages: Vec<OpenAiMessage<'_>> = req
        .messages
        .iter()
        .map(|m| OpenAiMessage {
            role: role_to_str(m.role),
            content: &m.content,
        })
        .collect();

    let body = ChatCompletionRequest {
        model,
        messages,
        stream: true,
        max_tokens: req.max_tokens,
    };

    let body_value = match serde_json::to_value(&body) {
        Ok(v) => v,
        Err(e) => {
            return vec![Err(LlmError::Transport(format!(
                "request serialization failed: {e}"
            )))]
        }
    };

    let timeout = Duration::from_millis(timeout_ms);
    let agent = ureq::AgentBuilder::new().timeout(timeout).build();

    let mut request = agent.post(&url).set("content-type", "application/json");
    if let Some(key) = api_key {
        request = request.set("authorization", &format!("Bearer {key}"));
    }

    let raw_response = match request.send_json(body_value) {
        Ok(r) => r,
        Err(err) => return vec![Err(map_ureq_error(err, timeout_ms))],
    };

    let status = raw_response.status();
    if status != 200 {
        return vec![Err(LlmError::Upstream(format!("HTTP {status}")))];
    }

    let body_text = match raw_response.into_string() {
        Ok(s) => s,
        Err(e) => {
            return vec![Err(LlmError::Transport(format!(
                "reading streaming response body: {e}"
            )))]
        }
    };

    let mut chunks = Vec::new();

    for line in body_text.lines() {
        // SSE frames are separated by blank lines; "event:" lines carry no
        // payload for this API.
        if line.is_empty() || line.starts_with("event:") {
            continue;
        }

        let data = match line.strip_prefix("data:") {
            Some(rest) => rest.trim(),
            None => continue,
        };

        // "[DONE]" is the OpenAI end-of-stream sentinel, not JSON.
        if data == "[DONE]" {
            chunks.push(Ok(StreamChunk {
                delta: String::new(),
                finish_reason: Some("stop".into()),
            }));
            return chunks;
        }

        let envelope: StreamChunkEnvelope = match serde_json::from_str(data) {
            Ok(v) => v,
            Err(e) => {
                // Surface the bad frame but keep consuming the stream.
                chunks.push(Err(LlmError::Parse(format!(
                    "openai-compat SSE data parse: {e}: {data}"
                ))));
                continue;
            }
        };

        let choice = match envelope.choices.into_iter().next() {
            Some(c) => c,
            None => continue,
        };

        let finish_reason = choice.finish_reason;
        let delta_text = choice.delta.content;

        chunks.push(Ok(StreamChunk {
            delta: delta_text,
            finish_reason,
        }));
    }

    chunks
}

/// Maps a `ureq` error onto this crate's `LlmError`, sniffing transport
/// messages for timeouts since ureq has no dedicated timeout variant.
fn map_ureq_error(err: ureq::Error, timeout_ms: u64) -> LlmError {
    match err {
        ureq::Error::Transport(t) => {
            let msg = t.to_string();
            if is_timeout_message(&msg) {
                LlmError::Timeout { timeout_ms }
            } else {
                LlmError::Transport(msg)
            }
        }
        ureq::Error::Status(code, _) => LlmError::Upstream(format!("HTTP {code}")),
    }
}

/// Heuristic: does a transport error message look like a timeout?
fn is_timeout_message(msg: &str) -> bool {
    let lower = msg.to_ascii_lowercase();
    lower.contains("timed out") || lower.contains("deadline exceeded") || lower.contains("timeout")
}

/// Maps an `LlmRole` to its OpenAI wire-format role string.
fn role_to_str(role: LlmRole) -> &'static str {
    match role {
        LlmRole::User => "user",
        LlmRole::Assistant => "assistant",
        LlmRole::Tool => "tool",
    }
}
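
// A minimal test sketch for the pure helpers above. These cases follow
// directly from the implementations in this file; they assume no network
// access and nothing beyond what this module defines.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn loopback_urls_classify_as_local() {
        assert_eq!(
            ceiling_for_url("http://127.0.0.1:8080").unwrap(),
            RuntimeCeiling::LocalUnsigned
        );
        assert_eq!(
            ceiling_for_url("http://localhost:11434").unwrap(),
            RuntimeCeiling::LocalUnsigned
        );
        assert_eq!(
            ceiling_for_url("http://[::1]:8080").unwrap(),
            RuntimeCeiling::LocalUnsigned
        );
    }

    #[test]
    fn remote_and_malformed_urls() {
        assert_eq!(
            ceiling_for_url("https://api.example.com").unwrap(),
            RuntimeCeiling::RemoteUnsigned
        );
        assert!(ceiling_for_url("ftp://example.com").is_err());
        assert!(ceiling_for_url("http:///v1").is_err());
    }

    #[test]
    fn host_extraction_handles_ports_paths_and_ipv6() {
        assert_eq!(extract_host("example.com:8080/v1"), Some("example.com"));
        assert_eq!(extract_host("[::1]:8080/v1"), Some("::1"));
        assert_eq!(extract_host("[::1]junk"), None);
        assert_eq!(extract_host("/v1"), None);
    }

    #[test]
    fn timeout_messages_are_recognized() {
        assert!(is_timeout_message("Connection timed out"));
        assert!(is_timeout_message("deadline exceeded while reading"));
        assert!(!is_timeout_message("connection refused"));
    }

    #[test]
    fn request_body_serializes_to_openai_shape() {
        let body = ChatCompletionRequest {
            model: "m",
            messages: vec![OpenAiMessage { role: "user", content: "hi" }],
            stream: false,
            max_tokens: 16,
        };
        let v = serde_json::to_value(&body).unwrap();
        assert_eq!(v["model"], "m");
        assert_eq!(v["messages"][0]["role"], "user");
        assert_eq!(v["stream"], false);
        assert_eq!(v["max_tokens"], 16);
    }
}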