use std::net::IpAddr;
use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::adapter::{
blake3_hex, BoxStream, LlmAdapter, LlmError, LlmRequest, LlmResponse, LlmRole, StreamChunk,
TokenUsage,
};
use crate::sensitivity::{check_remote_prompt_sensitivity, MaxSensitivity};
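/// Trust ceiling inferred from the endpoint address: loopback hosts get
/// `LocalUnsigned`, everything else gets `RemoteUnsigned`, meaning prompt
/// content leaves the machine.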
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuntimeCeiling {
LocalUnsigned,
RemoteUnsigned,
}
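/// Adapter for any server that speaks the OpenAI `/v1/chat/completions` wire
/// format (e.g. llama.cpp server, vLLM, or OpenAI itself). Prompts are gated
/// by [`MaxSensitivity`] before any text goes over the wire.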
#[derive(Debug, Clone)]
pub struct OpenAiCompatAdapter {
base_url: String,
model: String,
api_key: Option<String>,
timeout_ms: u64,
ceiling: RuntimeCeiling,
max_sensitivity: MaxSensitivity,
}
impl OpenAiCompatAdapter {
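    /// Builds an adapter after validating the endpoint: fails if `model` is
    /// empty or `base_url` lacks an `http(s)://` scheme or host, and warns on
    /// stderr when the endpoint is not loopback (every prompt will then leave
    /// the machine). An empty `api_key` is treated as absent;
    /// `max_sensitivity` defaults to `MaxSensitivity::Medium`.
    ///
    /// A minimal construction sketch (port and model id are illustrative):
    ///
    /// ```ignore
    /// let adapter = OpenAiCompatAdapter::new(
    ///     "http://127.0.0.1:8080", // loopback => RuntimeCeiling::LocalUnsigned
    ///     "llama-3.1-8b-instruct", // any non-empty model id
    ///     None,                    // no API key for a local server
    ///     30_000,                  // request timeout in milliseconds
    ///     None,                    // default sensitivity ceiling (Medium)
    /// )?;
    /// assert_eq!(adapter.runtime_ceiling(), RuntimeCeiling::LocalUnsigned);
    /// ```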
pub fn new(
base_url: impl Into<String>,
model: impl Into<String>,
api_key: Option<String>,
timeout_ms: u64,
max_sensitivity: Option<MaxSensitivity>,
) -> Result<Self, LlmError> {
let base_url = base_url.into();
let model = model.into();
if model.is_empty() {
return Err(LlmError::InvalidRequest(
"openai-compat: model must not be empty".to_string(),
));
}
let ceiling = ceiling_for_url(&base_url)?;
if ceiling == RuntimeCeiling::RemoteUnsigned {
eprintln!(
"cortex: openai-compat: WARNING: endpoint {} is not loopback-only. \
All prompt content will be sent to this remote server.",
base_url
);
}
let api_key = api_key.filter(|k| !k.is_empty());
Ok(Self {
base_url,
model,
api_key,
timeout_ms,
ceiling,
max_sensitivity: max_sensitivity.unwrap_or(MaxSensitivity::Medium),
})
}
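    /// Returns the ceiling computed from `base_url` at construction time.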
#[must_use]
pub fn runtime_ceiling(&self) -> RuntimeCeiling {
self.ceiling
}
}
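/// Classifies a base URL by its host: loopback hosts yield `LocalUnsigned`,
/// everything else `RemoteUnsigned`. The scheme is validated but does not
/// affect the ceiling.
///
/// ```ignore
/// assert_eq!(ceiling_for_url("http://localhost:8080")?, RuntimeCeiling::LocalUnsigned);
/// assert_eq!(ceiling_for_url("https://api.example.com")?, RuntimeCeiling::RemoteUnsigned);
/// ```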
fn ceiling_for_url(base_url: &str) -> Result<RuntimeCeiling, LlmError> {
let rest = if let Some(r) = base_url.strip_prefix("http://") {
r
} else if let Some(r) = base_url.strip_prefix("https://") {
r
} else {
return Err(LlmError::InvalidRequest(format!(
"openai-compat: base_url must start with http:// or https://: {base_url}"
)));
};
let host = extract_host(rest).ok_or_else(|| {
LlmError::InvalidRequest(format!(
"openai-compat: base_url must contain a host: {base_url}"
))
})?;
if is_loopback_host(host) {
Ok(RuntimeCeiling::LocalUnsigned)
} else {
Ok(RuntimeCeiling::RemoteUnsigned)
}
}
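/// Pulls the host out of the post-scheme remainder of a URL: strips the
/// path/query/fragment, then the port, and unwraps bracketed IPv6 literals
/// (`[::1]:8080` -> `::1`).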
fn extract_host(rest: &str) -> Option<&str> {
let authority = rest.split(['/', '?', '#']).next().unwrap_or_default();
if authority.is_empty() {
return None;
}
if let Some(after_open) = authority.strip_prefix('[') {
let (host, suffix) = after_open.split_once(']')?;
if suffix.is_empty() || suffix.starts_with(':') {
return Some(host);
}
return None;
}
let host = authority.split(':').next().unwrap_or_default();
if host.is_empty() {
None
} else {
Some(host)
}
}
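/// True for `localhost` (case-insensitive) and any literal loopback IP
/// (127.0.0.0/8, `::1`). DNS names are not resolved, so a hostname that
/// happens to point at 127.0.0.1 is still treated as remote.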
fn is_loopback_host(host: &str) -> bool {
if host.eq_ignore_ascii_case("localhost") {
return true;
}
host.parse::<IpAddr>().is_ok_and(|ip| ip.is_loopback())
}
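/// Request body for `POST /v1/chat/completions` (OpenAI wire format).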
#[derive(Debug, Serialize)]
struct ChatCompletionRequest<'a> {
model: &'a str,
messages: Vec<OpenAiMessage<'a>>,
stream: bool,
max_tokens: u32,
}
#[derive(Debug, Serialize)]
struct OpenAiMessage<'a> {
role: &'a str,
content: &'a str,
}
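/// Non-streaming response body. Fields are defaulted so that servers which
/// omit `usage` or return empty `choices` still deserialize cleanly.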
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
#[serde(default)]
choices: Vec<Choice>,
#[serde(default)]
usage: Option<OpenAiUsage>,
}
#[derive(Debug, Deserialize)]
struct Choice {
#[serde(default)]
message: ChoiceMessage,
}
#[derive(Debug, Default, Deserialize)]
struct ChoiceMessage {
#[serde(default)]
content: String,
}
#[derive(Debug, Deserialize)]
struct OpenAiUsage {
#[serde(default)]
prompt_tokens: u32,
#[serde(default)]
completion_tokens: u32,
}
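/// One `data:` payload of the SSE stream produced when `stream: true`.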
#[derive(Debug, Deserialize)]
struct StreamChunkEnvelope {
#[serde(default)]
choices: Vec<StreamChoice>,
}
#[derive(Debug, Default, Deserialize)]
struct StreamChoice {
#[serde(default)]
delta: StreamDelta,
finish_reason: Option<String>,
}
#[derive(Debug, Default, Deserialize)]
struct StreamDelta {
#[serde(default)]
content: String,
}
#[async_trait]
impl LlmAdapter for OpenAiCompatAdapter {
fn adapter_id(&self) -> &'static str {
"openai-compat"
}
async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
let prompt_text: String = std::iter::once(req.system.as_str())
.chain(req.messages.iter().map(|m| m.content.as_str()))
.collect::<Vec<_>>()
.join("\n");
check_remote_prompt_sensitivity(&prompt_text, self.max_sensitivity)?;
let base_url = self.base_url.clone();
let model = self.model.clone();
let api_key = self.api_key.clone();
let timeout_ms = self.timeout_ms;
        // `ureq` is blocking, so run the request on the blocking pool instead
        // of stalling the async executor.
        tokio::task::spawn_blocking(move || {
            call_openai_compat(&base_url, &model, api_key.as_deref(), &req, timeout_ms)
        })
        .await
        .map_err(|e| LlmError::Transport(format!("spawn_blocking join error: {e}")))?
}
    fn stream_boxed(&self, req: LlmRequest) -> BoxStream<'_> {
        stream_openai_compat_sse(
            self.base_url.clone(),
            self.model.clone(),
            self.api_key.clone(),
            self.max_sensitivity,
            req,
        )
    }
}
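/// Blocking non-streaming call, intended to run inside `spawn_blocking`.
/// The raw response text is hashed into `raw_hash` via `blake3_hex`, so the
/// stored hash covers exactly the bytes the server sent back.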
fn call_openai_compat(
base_url: &str,
model: &str,
api_key: Option<&str>,
req: &LlmRequest,
timeout_ms: u64,
) -> Result<LlmResponse, LlmError> {
    // Trim any trailing slash so "http://host:8080/" does not yield
    // "//v1/chat/completions".
    let url = format!("{}/v1/chat/completions", base_url.trim_end_matches('/'));
    // The system prompt is carried separately on `LlmRequest`; prepend it as
    // a "system" message so it actually reaches the server.
    let mut messages: Vec<OpenAiMessage<'_>> = Vec::with_capacity(req.messages.len() + 1);
    if !req.system.is_empty() {
        messages.push(OpenAiMessage {
            role: "system",
            content: &req.system,
        });
    }
    messages.extend(req.messages.iter().map(|m| OpenAiMessage {
        role: role_to_str(m.role),
        content: &m.content,
    }));
    let body = ChatCompletionRequest {
        model,
        messages,
        stream: false,
        max_tokens: req.max_tokens,
    };
let body_value = serde_json::to_value(&body)
.map_err(|e| LlmError::Transport(format!("request serialization failed: {e}")))?;
let timeout = Duration::from_millis(timeout_ms);
let agent = ureq::AgentBuilder::new().timeout(timeout).build();
let mut request = agent.post(&url).set("content-type", "application/json");
if let Some(key) = api_key {
request = request.set("authorization", &format!("Bearer {key}"));
}
let raw_response = request
.send_json(body_value)
.map_err(|err| map_ureq_error(err, timeout_ms))?;
let status = raw_response.status();
if status != 200 {
return Err(LlmError::Upstream(format!("HTTP {status}")));
}
let response_text = raw_response
.into_string()
.map_err(|e| LlmError::Transport(format!("reading response body: {e}")))?;
let parsed: ChatCompletionResponse = serde_json::from_str(&response_text)
.map_err(|e| LlmError::Parse(format!("openai-compat response parse: {e}")))?;
let text = parsed
.choices
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| {
LlmError::Parse("openai-compat response contained no choices".to_string())
})?;
let raw_hash = blake3_hex(response_text.as_bytes());
let usage = parsed.usage.map(|u| TokenUsage {
prompt_tokens: u.prompt_tokens,
completion_tokens: u.completion_tokens,
});
Ok(LlmResponse {
text,
parsed_json: None,
model: model.to_string(),
usage,
raw_hash,
})
}
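/// Wraps the blocking SSE call in an async stream, applying the same
/// sensitivity gate as `complete` first. Note that the HTTP client reads the
/// whole response before parsing, so chunks arrive in one batch after the
/// request finishes rather than incrementally; the per-request `timeout_ms`
/// on `LlmRequest` bounds that wait.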
fn stream_openai_compat_sse(
    base_url: String,
    model: String,
    api_key: Option<String>,
    max_sensitivity: MaxSensitivity,
    req: LlmRequest,
) -> BoxStream<'static> {
    Box::pin(async_stream::stream! {
        // Mirror the sensitivity gate in `complete` so the streaming path
        // cannot send prompt content that the blocking path would reject.
        let prompt_text: String = std::iter::once(req.system.as_str())
            .chain(req.messages.iter().map(|m| m.content.as_str()))
            .collect::<Vec<_>>()
            .join("\n");
        if let Err(e) = check_remote_prompt_sensitivity(&prompt_text, max_sensitivity) {
            yield Err(e);
        } else {
            let timeout_ms = req.timeout_ms;
            let result = tokio::task::spawn_blocking(move || {
                call_openai_compat_streaming(&base_url, &model, api_key.as_deref(), &req, timeout_ms)
            })
            .await;
            match result {
                Ok(chunks) => {
                    for chunk in chunks {
                        yield chunk;
                    }
                }
                Err(e) => yield Err(LlmError::Transport(format!("spawn_blocking join error: {e}"))),
            }
        }
    })
}
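/// Blocking call with `stream: true`: reads the full SSE body, then parses
/// each `data:` line into a `StreamChunk`. A terminal `[DONE]` marker is
/// turned into a synthetic empty chunk with `finish_reason = "stop"`. Errors
/// are returned in-band so the caller can yield them downstream.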
fn call_openai_compat_streaming(
base_url: &str,
model: &str,
api_key: Option<&str>,
req: &LlmRequest,
timeout_ms: u64,
) -> Vec<Result<StreamChunk, LlmError>> {
    let url = format!("{}/v1/chat/completions", base_url.trim_end_matches('/'));
    // Same message assembly as `call_openai_compat`, including the prepended
    // system message, but with `stream: true`.
    let mut messages: Vec<OpenAiMessage<'_>> = Vec::with_capacity(req.messages.len() + 1);
    if !req.system.is_empty() {
        messages.push(OpenAiMessage {
            role: "system",
            content: &req.system,
        });
    }
    messages.extend(req.messages.iter().map(|m| OpenAiMessage {
        role: role_to_str(m.role),
        content: &m.content,
    }));
    let body = ChatCompletionRequest {
        model,
        messages,
        stream: true,
        max_tokens: req.max_tokens,
    };
let body_value = match serde_json::to_value(&body) {
Ok(v) => v,
Err(e) => {
return vec![Err(LlmError::Transport(format!(
"request serialization failed: {e}"
)))]
}
};
let timeout = Duration::from_millis(timeout_ms);
let agent = ureq::AgentBuilder::new().timeout(timeout).build();
let mut request = agent.post(&url).set("content-type", "application/json");
if let Some(key) = api_key {
request = request.set("authorization", &format!("Bearer {key}"));
}
let raw_response = match request.send_json(body_value) {
Ok(r) => r,
Err(err) => return vec![Err(map_ureq_error(err, timeout_ms))],
};
let status = raw_response.status();
if status != 200 {
return vec![Err(LlmError::Upstream(format!("HTTP {status}")))];
}
let body_text = match raw_response.into_string() {
Ok(s) => s,
Err(e) => {
return vec![Err(LlmError::Transport(format!(
"reading streaming response body: {e}"
)))]
}
};
let mut chunks = Vec::new();
for line in body_text.lines() {
if line.is_empty() || line.starts_with("event:") {
continue;
}
let data = match line.strip_prefix("data:") {
Some(rest) => rest.trim(),
None => continue,
};
if data == "[DONE]" {
chunks.push(Ok(StreamChunk {
delta: String::new(),
finish_reason: Some("stop".into()),
}));
return chunks;
}
let envelope: StreamChunkEnvelope = match serde_json::from_str(data) {
Ok(v) => v,
Err(e) => {
chunks.push(Err(LlmError::Parse(format!(
"openai-compat SSE data parse: {e}: {data}"
))));
continue;
}
};
let choice = match envelope.choices.into_iter().next() {
Some(c) => c,
None => continue,
};
let finish_reason = choice.finish_reason;
let delta_text = choice.delta.content;
chunks.push(Ok(StreamChunk {
delta: delta_text,
finish_reason,
}));
}
chunks
}
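/// Maps `ureq` failures onto `LlmError`, sniffing transport error text for
/// timeout wording so deadline failures surface as `LlmError::Timeout`.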
fn map_ureq_error(err: ureq::Error, timeout_ms: u64) -> LlmError {
match err {
ureq::Error::Transport(t) => {
let msg = t.to_string();
if is_timeout_message(&msg) {
LlmError::Timeout { timeout_ms }
} else {
LlmError::Transport(msg)
}
}
ureq::Error::Status(code, _) => LlmError::Upstream(format!("HTTP {code}")),
}
}
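/// Heuristic: detect timeouts by matching common phrasings in the transport
/// error message, since the error is only available here as text.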
fn is_timeout_message(msg: &str) -> bool {
let lower = msg.to_ascii_lowercase();
lower.contains("timed out") || lower.contains("deadline exceeded") || lower.contains("timeout")
}
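/// Maps internal roles onto OpenAI wire roles. There is no `System` variant:
/// the system prompt travels on `LlmRequest::system` and is prepended as a
/// `"system"` message where the request body is built.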
fn role_to_str(role: LlmRole) -> &'static str {
match role {
LlmRole::User => "user",
LlmRole::Assistant => "assistant",
LlmRole::Tool => "tool",
}
}