use super::sse_pump::pump_openai_sse;
use super::wire::tools_wire;
use crate::ChatMessage;
use crate::chat::{ChatEvent, ChatProvider, ToolDef};
use anyhow::{Context, Result, anyhow};
use async_trait::async_trait;
use tokio::sync::mpsc::Sender;
// --- OpenRouter -----------------------------------------------------------

/// Endpoint that `OpenRouterProvider::chat_stream` POSTs chat requests to.
const OPENROUTER_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
/// Connect timeout, in seconds, for the OpenRouter HTTP client.
const OPENROUTER_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Value passed to the client builder's `timeout`, in seconds, for OpenRouter requests.
const OPENROUTER_REQUEST_TIMEOUT_SECS: u64 = 120;
/// Value sent in the `HTTP-Referer` header on every OpenRouter request.
const HTTP_REFERER: &str = "https://github.com/bobmatnyc/trusty-common";
/// Value sent in the `X-Title` header on every OpenRouter request.
const X_TITLE: &str = "trusty-common";

// --- Local server ---------------------------------------------------------

/// Timeout, in seconds, used for the `/v1/models` probe (both connect and
/// total) and as the connect timeout for local chat requests.
const LOCAL_PROBE_TIMEOUT_SECS: u64 = 1;
/// Value passed to the client builder's `timeout`, in seconds, for local chat requests.
const LOCAL_REQUEST_TIMEOUT_SECS: u64 = 120;
/// Chat provider configuration for the OpenRouter API.
///
/// `Debug` is implemented by hand so that formatting a provider (for logs or
/// error reports) never prints the API key.
#[derive(Clone)]
pub struct OpenRouterProvider {
    /// Credential sent as the bearer token on chat requests.
    pub api_key: String,
    /// Model identifier placed in the `model` field of the request body.
    pub model: String,
}

impl OpenRouterProvider {
    /// Builds a provider from an API key and a model identifier.
    ///
    /// Both arguments are converted into owned `String`s and stored as given;
    /// no validation happens here.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
        }
    }
}

impl std::fmt::Debug for OpenRouterProvider {
    /// Formats the provider with the API key replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenRouterProvider")
            // Never expose the secret, even in debug output.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .finish()
    }
}
/// `ChatProvider` implementation that streams chat completions from OpenRouter.
#[async_trait]
impl ChatProvider for OpenRouterProvider {
    /// Provider identifier; always `"openrouter"`.
    fn name(&self) -> &str {
        "openrouter"
    }

    /// Model identifier this provider was constructed with.
    fn model(&self) -> &str {
        &self.model
    }

    /// POSTs a streaming chat request to `OPENROUTER_URL` and hands the
    /// response to `pump_openai_sse` together with `tx`.
    ///
    /// The request body carries the configured model, `messages`, `stream: true`
    /// and the wire form of `tools`. The API key is sent as a bearer token,
    /// alongside the `HTTP-Referer` and `X-Title` headers.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - the API key is empty or contains only whitespace (no request is made);
    /// - the HTTP client cannot be built or the request cannot be sent;
    /// - the server answers with a non-success status (the response body is
    ///   included in the message on a best-effort basis);
    /// - `pump_openai_sse` returns an error.
    async fn chat_stream(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<ToolDef>,
        tx: Sender<ChatEvent>,
    ) -> Result<()> {
        // Keys read from files or environment dumps often carry stray
        // whitespace (typically a trailing newline). Trim it so that a
        // whitespace-only key is rejected up front and a padded key does not
        // produce an invalid Authorization header.
        let api_key = self.api_key.trim();
        if api_key.is_empty() {
            return Err(anyhow!("openrouter api key is empty"));
        }

        // NOTE(review): reqwest documents `timeout` as a total deadline that
        // also covers reading the response body, so a stream running longer
        // than OPENROUTER_REQUEST_TIMEOUT_SECS would be cut off — confirm this
        // cap is intended for streamed completions.
        let client = reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(
                OPENROUTER_CONNECT_TIMEOUT_SECS,
            ))
            .timeout(std::time::Duration::from_secs(
                OPENROUTER_REQUEST_TIMEOUT_SECS,
            ))
            .build()
            .context("build reqwest client for OpenRouterProvider::chat_stream")?;

        let tw = tools_wire(&tools);
        let body = super::wire::ChatRequestWire {
            model: &self.model,
            messages: &messages,
            stream: true,
            tools: tw,
        };

        let resp = client
            .post(OPENROUTER_URL)
            .bearer_auth(api_key)
            .header("HTTP-Referer", HTTP_REFERER)
            .header("X-Title", X_TITLE)
            .json(&body)
            .send()
            .await
            .context("POST openrouter chat completions (stream)")?;

        let status = resp.status();
        if !status.is_success() {
            // Best effort: if the error body cannot be read, report the
            // status with an empty body rather than masking it.
            let text = resp.text().await.unwrap_or_default();
            return Err(anyhow!("openrouter HTTP {status}: {text}"));
        }

        pump_openai_sse(resp, tx).await
    }
}
/// Chat provider configuration for a local server reached over HTTP.
#[derive(Debug, Clone)]
pub struct OllamaProvider {
    /// Base URL of the server; trailing slashes are stripped before the
    /// `/v1/chat/completions` path is appended when a chat request is made.
    pub base_url: String,
    /// Model identifier placed in the `model` field of the request body.
    pub model: String,
}

impl OllamaProvider {
    /// Builds a provider from a base URL and a model identifier.
    ///
    /// Both arguments are converted into owned `String`s and stored as given;
    /// no normalisation or validation happens here.
    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            model: model.into(),
        }
    }
}
/// `ChatProvider` implementation that streams chat completions from the
/// server at `base_url`.
#[async_trait]
impl ChatProvider for OllamaProvider {
    /// Provider identifier; always `"ollama"`.
    fn name(&self) -> &str {
        "ollama"
    }

    /// Model identifier this provider was constructed with.
    fn model(&self) -> &str {
        self.model.as_str()
    }

    /// POSTs a streaming chat request to `{base_url}/v1/chat/completions` and
    /// hands the response to `pump_openai_sse` together with `tx`.
    ///
    /// The request body carries the configured model, `messages`,
    /// `stream: true` and the wire form of `tools`. No authentication header
    /// is sent.
    ///
    /// # Errors
    ///
    /// Returns an error when the HTTP client cannot be built, the request
    /// cannot be sent, the server answers with a non-success status (the
    /// response body is included on a best-effort basis), or
    /// `pump_openai_sse` returns an error.
    async fn chat_stream(
        &self,
        messages: Vec<ChatMessage>,
        tools: Vec<ToolDef>,
        tx: Sender<ChatEvent>,
    ) -> Result<()> {
        use std::time::Duration;

        // The connect timeout reuses the short probe window, while the
        // overall timeout uses the longer request window.
        // NOTE(review): reqwest documents `timeout` as a total deadline that
        // also covers reading the response body, so a stream running longer
        // than LOCAL_REQUEST_TIMEOUT_SECS would be cut off — confirm intended.
        let client = reqwest::Client::builder()
            .connect_timeout(Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
            .timeout(Duration::from_secs(LOCAL_REQUEST_TIMEOUT_SECS))
            .build()
            .context("build reqwest client for OllamaProvider::chat_stream")?;

        // Strip trailing slashes so the joined URL has exactly one separator.
        let base = self.base_url.trim_end_matches('/');
        let url = format!("{}/v1/chat/completions", base);

        let payload = super::wire::ChatRequestWire {
            model: &self.model,
            messages: &messages,
            stream: true,
            tools: tools_wire(&tools),
        };

        let resp = client
            .post(&url)
            .json(&payload)
            .send()
            .await
            .with_context(|| format!("POST {url}"))?;

        let status = resp.status();
        if status.is_success() {
            return pump_openai_sse(resp, tx).await;
        }

        // Best effort: an unreadable error body is reported as empty.
        let text = resp.text().await.unwrap_or_default();
        Err(anyhow!("local chat HTTP {status}: {text}"))
    }
}
/// Probes `base_url` for a local chat server and returns a provider if one answers.
///
/// Sends `GET {base_url}/v1/models` (trailing slashes on `base_url` are
/// stripped first) using `LOCAL_PROBE_TIMEOUT_SECS` as both the connect and
/// the total timeout. The response body is not read.
///
/// Returns `Some` when the response status is a success. The provider keeps
/// `base_url` exactly as passed in and has an **empty** model name.
/// NOTE(review): presumably callers fill in `model` afterwards — confirm.
///
/// Returns `None` when the HTTP client cannot be built, the request fails
/// (including on timeout), or the status is not a success.
pub async fn auto_detect_local_provider(base_url: &str) -> Option<OllamaProvider> {
    let probe_window = std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS);
    let http = reqwest::Client::builder()
        .connect_timeout(probe_window)
        .timeout(probe_window)
        .build()
        .ok()?;

    let models_url = format!("{}/v1/models", base_url.trim_end_matches('/'));

    // Any transport error is treated the same as "no server here".
    let resp = http.get(&models_url).send().await.ok()?;

    resp.status()
        .is_success()
        .then(|| OllamaProvider::new(base_url, String::new()))
}