use base64::engine::general_purpose::STANDARD as B64;
use base64::Engine;
use ollama_rs::Ollama;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use url::Url;
use crate::config::{AuthConfig, RuntimeConfig, RuntimeMode};
use crate::error::{Result, RuntimeError};
use crate::guard::ExecutionGuard;
#[cfg(feature = "stream")]
use crate::guard::GuardedStream;
/// Handle bundling an Ollama HTTP client with the execution policy applied to calls made
/// through it.
pub struct OllamaRuntime {
    /// Underlying `ollama_rs` client, created from the normalized base URL.
    client: Ollama,
    /// Guard constructed from the configured `max_concurrent`, `timeout` and `max_retries`.
    guard: ExecutionGuard,
    /// Copied from `RuntimeConfig::auto_pull`; forwarded to `ensure_local_model`.
    auto_pull: bool,
    /// Copied from `RuntimeConfig::mode`; forwarded to `ensure_local_model`.
    mode: RuntimeMode,
}
impl OllamaRuntime {
    /// Builds a runtime from `config`.
    ///
    /// Steps, in order: validate the configuration, normalize `config.base_url`, build the
    /// `reqwest` client (default auth headers and timeouts), create the `ollama_rs` client, and
    /// create the [`ExecutionGuard`] from `max_concurrent`, `timeout` and `max_retries`.
    ///
    /// # Errors
    ///
    /// Returns an error when validation fails, when the base URL cannot be parsed or
    /// normalized, or when the HTTP client cannot be built.
    pub async fn new(config: RuntimeConfig) -> Result<Self> {
        config.validate()?;
        let base = parse_and_normalize_base_url(&config.base_url)?;
        // `Url::port()` is `None` whenever the port equals the scheme's default (the `url`
        // crate never stores it), so an `https` base URL on 443 would otherwise be rejected
        // here. `port_or_known_default()` resolves that case to the scheme default.
        let port = base.port_or_known_default().ok_or_else(|| {
            RuntimeError::Other("internal error: base_url normalization did not set a port".into())
        })?;
        let reqwest_client = build_reqwest_client(&config)?;
        let client = Ollama::new_with_client(base.as_str(), port, reqwest_client);
        let guard = ExecutionGuard::new(config.max_concurrent, config.timeout, config.max_retries);
        Ok(Self {
            client,
            guard,
            auto_pull: config.auto_pull,
            mode: config.mode,
        })
    }

    /// Returns the underlying `ollama_rs` client.
    pub fn client(&self) -> &Ollama {
        &self.client
    }

    /// Returns the execution guard used by [`call`](Self::call).
    pub fn guard(&self) -> &ExecutionGuard {
        &self.guard
    }

    /// Returns the `auto_pull` flag this runtime was configured with.
    pub fn auto_pull(&self) -> bool {
        self.auto_pull
    }

    /// Returns the runtime mode this runtime was configured with.
    pub fn mode(&self) -> RuntimeMode {
        self.mode
    }

    /// Delegates to `crate::model::ensure_local_model` for `model`, passing this runtime's
    /// `auto_pull` flag, mode and client.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `ensure_local_model` returns.
    pub async fn ensure(&self, model: &str) -> Result<()> {
        crate::model::ensure_local_model(self.auto_pull, self.mode, self.client(), model).await
    }

    /// Runs `f` against the underlying client through the execution guard.
    ///
    /// `f` is `Fn` rather than `FnOnce`, so the guard is free to invoke it more than once
    /// (presumably for retries; see `ExecutionGuard::run`).
    ///
    /// # Errors
    ///
    /// Returns whatever error `ExecutionGuard::run` produces.
    pub async fn call<F, Fut, T>(&self, f: F) -> Result<T>
    where
        F: Fn(&Ollama) -> Fut,
        Fut: std::future::Future<Output = ollama_rs::error::Result<T>>,
    {
        self.guard().run(|| f(self.client())).await
    }

    /// Streaming counterpart of [`call`](Self::call): runs `f` through
    /// `ExecutionGuard::run_stream` and returns the resulting [`GuardedStream`].
    ///
    /// # Errors
    ///
    /// Returns whatever error `ExecutionGuard::run_stream` produces.
    #[cfg(feature = "stream")]
    pub async fn call_stream<F, Fut, S>(&self, f: F) -> Result<GuardedStream<S>>
    where
        F: Fn(&Ollama) -> Fut,
        Fut: std::future::Future<Output = ollama_rs::error::Result<S>>,
        S: tokio_stream::Stream + Unpin,
    {
        self.guard().run_stream(|| f(self.client())).await
    }
}
/// Parses `raw` as an absolute URL and normalizes it for use as the Ollama base URL.
///
/// Normalization rules:
/// - the URL must have a host;
/// - the scheme must be `http` or `https`, whether or not a port is given;
/// - a missing port defaults to 11434 for `http` and 443 for `https`;
/// - the path is made to end with `/`.
///
/// The `url` crate never stores a port equal to the scheme's default, so `Url::port()` on the
/// returned value is `None` for `https` on 443; read the effective port with
/// `Url::port_or_known_default()`.
///
/// NOTE(review): for the same reason an explicit `http://host:80` is indistinguishable from
/// `http://host` after parsing and is therefore given port 11434.
///
/// # Errors
///
/// Returns `RuntimeError::Other` when `raw` does not parse, has no host, uses a scheme other
/// than `http`/`https`, or the default port cannot be applied.
fn parse_and_normalize_base_url(raw: &str) -> Result<Url> {
    let mut url =
        Url::parse(raw).map_err(|e| RuntimeError::Other(format!("invalid base_url: {e}")))?;
    if url.host().is_none() {
        return Err(RuntimeError::Other("base_url must include a host".into()));
    }

    // Validate the scheme up front: checking it only when the port is absent would let
    // URLs such as `ftp://host:2121` through.
    let default_port = match url.scheme() {
        "http" => 11434u16,
        "https" => 443u16,
        other => {
            return Err(RuntimeError::Other(format!(
                "unsupported URL scheme: {other} (expected http or https)"
            )));
        }
    };
    if url.port().is_none() {
        url.set_port(Some(default_port)).map_err(|_| {
            RuntimeError::Other("invalid base_url: could not apply default port".into())
        })?;
    }

    // http(s) URLs always have a path of at least "/", so only a missing trailing slash
    // needs fixing. Build the new path first so no borrow of `url` is held across `set_path`.
    if !url.path().ends_with('/') {
        let with_slash = format!("{}/", url.path());
        url.set_path(&with_slash);
    }
    Ok(url)
}
/// Creates the `reqwest` client used for all Ollama requests.
///
/// Any configured auth is installed as default headers, and the client is given the
/// configured `connect_timeout` and `timeout`.
///
/// # Errors
///
/// Returns `RuntimeError::Other` when an auth header is invalid or the client fails to build.
fn build_reqwest_client(config: &RuntimeConfig) -> Result<reqwest::Client> {
    let mut default_headers = HeaderMap::new();
    if let Some(auth) = config.auth.as_ref() {
        merge_auth_headers(&mut default_headers, auth)?;
    }

    let builder = reqwest::Client::builder()
        .default_headers(default_headers)
        .connect_timeout(config.connect_timeout)
        .timeout(config.timeout);

    match builder.build() {
        Ok(http) => Ok(http),
        Err(e) => Err(RuntimeError::Other(format!(
            "failed to build HTTP client: {e}"
        ))),
    }
}
/// Inserts the header described by `auth` into `headers`, replacing any existing value for
/// the same header name.
///
/// - `BearerToken` becomes `authorization: Bearer <token>`.
/// - `Basic` becomes `authorization: Basic <base64(username:password)>`.
/// - `CustomHeader` becomes `<key>: <value>`.
///
/// The inserted value is always marked sensitive so the credential is redacted from `Debug`
/// output of the header map, as reqwest's own `bearer_auth`/`basic_auth` helpers do.
///
/// # Errors
///
/// Returns `RuntimeError::Other` when the header name or value is not a valid HTTP header.
fn merge_auth_headers(headers: &mut HeaderMap, auth: &AuthConfig) -> Result<()> {
    let (name, mut value) = match auth {
        AuthConfig::BearerToken(token) => (
            HeaderName::from_static("authorization"),
            HeaderValue::try_from(format!("Bearer {token}")).map_err(header_err)?,
        ),
        AuthConfig::Basic { username, password } => {
            let creds = format!("{}:{}", username, password);
            let b64 = B64.encode(creds.as_bytes());
            (
                HeaderName::from_static("authorization"),
                HeaderValue::try_from(format!("Basic {b64}")).map_err(header_err)?,
            )
        }
        AuthConfig::CustomHeader { key, value: raw } => (
            HeaderName::try_from(key.as_str()).map_err(header_err)?,
            HeaderValue::try_from(raw.as_str()).map_err(header_err)?,
        ),
    };
    // Every variant carries a secret; keep it out of logs and debug dumps.
    value.set_sensitive(true);
    headers.insert(name, value);
    Ok(())
}
/// Wraps a header construction failure in `RuntimeError::Other`, prefixing the message with
/// `invalid auth header: `.
fn header_err<E>(e: E) -> RuntimeError
where
    E: std::fmt::Display,
{
    let message = format!("invalid auth header: {e}");
    RuntimeError::Other(message)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare `http` host gains the Ollama default port and a root path.
    #[test]
    fn normalizes_empty_path_and_port() {
        let input = "http://127.0.0.1";
        let normalized = parse_and_normalize_base_url(input).unwrap();
        let rendered = normalized.as_str();
        assert_eq!(rendered, "http://127.0.0.1:11434/");
    }

    /// A sub-path without a trailing slash gets one appended.
    #[test]
    fn normalizes_subpath_trailing_slash() {
        let input = "http://proxy/v1/ollama";
        let normalized = parse_and_normalize_base_url(input).unwrap();
        let rendered = normalized.as_str();
        assert!(rendered.ends_with("/v1/ollama/"));
    }
}