// ollama-kit 0.1.0
//
// Runtime control (lifecycle + execution guards) for ollama-rs without wrapping its API.
// Documentation
use base64::engine::general_purpose::STANDARD as B64;
use base64::Engine;
use ollama_rs::Ollama;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use url::Url;

use crate::config::{AuthConfig, RuntimeConfig};
use crate::error::{Result, RuntimeError};
use crate::guard::ExecutionGuard;
#[cfg(feature = "stream")]
use crate::guard::GuardedStream;
use crate::model::ModelManager;

/// Coordinates a shared [`Ollama`] client, model lifecycle helpers, and execution guards.
///
/// Construct it with [`OllamaRuntime::new`]; the individual parts are reachable through
/// [`OllamaRuntime::client`], [`OllamaRuntime::models`], and [`OllamaRuntime::guard`].
pub struct OllamaRuntime {
    /// Client created in [`OllamaRuntime::new`] via [`Ollama::new_with_client`].
    client: Ollama,
    /// Model manager constructed from a clone of `client` plus the config's `auto_pull` and
    /// `mode` settings.
    manager: ModelManager,
    /// Execution guard configured from the config's `max_concurrent`, `timeout`, and
    /// `max_retries` settings.
    guard: ExecutionGuard,
}

impl OllamaRuntime {
    /// Validates config, builds an HTTP client (timeouts + optional auth), then constructs
    /// [`Ollama`] via [`Ollama::new_with_client`].
    ///
    /// # Errors
    ///
    /// Returns an error when `config.validate()` fails, when `config.base_url` cannot be
    /// parsed/normalized, when no port can be determined for the base URL, or when the HTTP
    /// client cannot be built.
    pub async fn new(config: RuntimeConfig) -> Result<Self> {
        config.validate()?;
        let base = parse_and_normalize_base_url(&config.base_url)?;

        // `Url::port` returns `None` whenever the port equals the scheme's default, and
        // `Url::set_port` normalizes a default port back to "no port". An `https://host`
        // URL therefore never reports `Some(443)` from `port()`, which used to make every
        // plain https base URL fail here. `port_or_known_default` covers both the explicit
        // and the scheme-default case.
        let port = base.port_or_known_default().ok_or_else(|| {
            RuntimeError::Other("internal error: base_url normalization did not set a port".into())
        })?;

        let reqwest_client = build_reqwest_client(&config)?;

        let client = Ollama::new_with_client(base.as_str(), port, reqwest_client);

        let guard = ExecutionGuard::new(config.max_concurrent, config.timeout, config.max_retries);

        // The manager gets its own clone of the client; the runtime keeps the original.
        let manager = ModelManager::new(client.clone(), config.auto_pull, config.mode);

        Ok(Self {
            client,
            manager,
            guard,
        })
    }

    /// Returns the [`Ollama`] client built in [`Self::new`].
    pub fn client(&self) -> &Ollama {
        &self.client
    }

    /// Returns the [`ExecutionGuard`] built in [`Self::new`].
    pub fn guard(&self) -> &ExecutionGuard {
        &self.guard
    }

    /// Returns the [`ModelManager`] built in [`Self::new`].
    pub fn models(&self) -> &ModelManager {
        &self.manager
    }

    /// Shortcut for [`ModelManager::ensure`]. Equivalent to `self.models().ensure(model).await`.
    pub async fn ensure_model(&self, model: &str) -> Result<()> {
        self.models().ensure(model).await
    }

    /// Shortcut for [`ExecutionGuard::run`]. Equivalent to `self.guard().run(f).await`.
    ///
    /// Prefer this at call sites when you only need guarded execution once; keep using
    /// [`Self::guard`] when you want a reusable handle (for example attaching metrics to the
    /// guard or holding [`ExecutionGuard::max_retries`] in scope).
    pub async fn run<F, Fut, T>(&self, f: F) -> Result<T>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = ollama_rs::error::Result<T>>,
    {
        self.guard().run(f).await
    }

    /// Shortcut for [`crate::guard::ExecutionGuard::run_stream`] when built with **`stream`**.
    /// Equivalent to `self.guard().run_stream(f).await`.
    #[cfg(feature = "stream")]
    pub async fn run_stream<F, Fut, S>(&self, f: F) -> Result<GuardedStream<S>>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = ollama_rs::error::Result<S>>,
        S: tokio_stream::Stream + Unpin,
    {
        self.guard().run_stream(f).await
    }
}

/// Parses `raw` and normalizes it into the base URL form used to build the client.
///
/// Normalization rules:
/// - the URL must include a host and use the `http` or `https` scheme (checked regardless of
///   whether a port was given);
/// - when [`Url::port`] reports no port, `http` URLs are given port `11434` and `https` URLs
///   are given `443`;
/// - the path always ends with `/`.
///
/// NOTE: [`Url::port`] returns `None` for a scheme's default port, so a normalized `https`
/// URL on 443 still reports `port() == None`; use [`Url::port_or_known_default`] to read the
/// effective port. For the same reason an explicit `:80` on an `http` URL cannot be told apart
/// from "no port" here and is replaced by `11434`.
///
/// # Errors
///
/// Returns [`RuntimeError::Other`] when `raw` is not a valid URL, has no host, uses a scheme
/// other than `http`/`https`, or the default port cannot be applied.
fn parse_and_normalize_base_url(raw: &str) -> Result<Url> {
    let mut url =
        Url::parse(raw).map_err(|e| RuntimeError::Other(format!("invalid base_url: {e}")))?;

    if url.host().is_none() {
        return Err(RuntimeError::Other("base_url must include a host".into()));
    }

    // Validate the scheme unconditionally: previously a URL such as `ftp://host:2121` skipped
    // this check just because it carried an explicit port.
    let default_port = match url.scheme() {
        "http" => 11434u16,
        "https" => 443u16,
        other => {
            return Err(RuntimeError::Other(format!(
                "unsupported URL scheme: {other} (expected http or https)"
            )));
        }
    };

    if url.port().is_none() {
        url.set_port(Some(default_port)).map_err(|_| {
            RuntimeError::Other("invalid base_url: could not apply default port".into())
        })?;
    }

    // Guarantee a trailing slash. An empty path becomes "/", and "/a/b" becomes "/a/b/".
    if !url.path().ends_with('/') {
        let with_slash = format!("{}/", url.path());
        url.set_path(&with_slash);
    }

    Ok(url)
}

/// Builds the `reqwest` client used for all requests.
///
/// The client carries the auth header derived from `config.auth` (when present) as a default
/// header, and applies `config.connect_timeout` and `config.timeout`.
///
/// # Errors
///
/// Returns an error when the auth header is invalid or when `reqwest` fails to build the
/// client.
fn build_reqwest_client(config: &RuntimeConfig) -> Result<reqwest::Client> {
    let mut default_headers = HeaderMap::new();
    if let Some(auth) = &config.auth {
        merge_auth_headers(&mut default_headers, auth)?;
    }

    let builder = reqwest::Client::builder()
        .default_headers(default_headers)
        .connect_timeout(config.connect_timeout)
        .timeout(config.timeout);

    match builder.build() {
        Ok(http_client) => Ok(http_client),
        Err(e) => Err(RuntimeError::Other(format!(
            "failed to build HTTP client: {e}"
        ))),
    }
}

/// Inserts the header described by `auth` into `headers`.
///
/// - [`AuthConfig::BearerToken`] becomes `authorization: Bearer <token>`.
/// - [`AuthConfig::Basic`] becomes `authorization: Basic <base64(username:password)>`.
/// - [`AuthConfig::CustomHeader`] becomes `<key>: <value>`.
///
/// The inserted value is flagged as sensitive so the credential is not printed when the
/// header map is `Debug`-formatted.
///
/// # Errors
///
/// Returns an error when the header name or value is not a valid HTTP header.
fn merge_auth_headers(headers: &mut HeaderMap, auth: &AuthConfig) -> Result<()> {
    let (name, mut header_value) = match auth {
        AuthConfig::BearerToken(token) => (
            HeaderName::from_static("authorization"),
            HeaderValue::try_from(format!("Bearer {token}")).map_err(header_err)?,
        ),
        AuthConfig::Basic { username, password } => {
            let creds = format!("{}:{}", username, password);
            let b64 = B64.encode(creds.as_bytes());
            (
                HeaderName::from_static("authorization"),
                HeaderValue::try_from(format!("Basic {b64}")).map_err(header_err)?,
            )
        }
        AuthConfig::CustomHeader { key, value } => (
            HeaderName::try_from(key.as_str()).map_err(header_err)?,
            HeaderValue::try_from(value.as_str()).map_err(header_err)?,
        ),
    };

    // Every variant carries a credential; keep it out of debug output.
    header_value.set_sensitive(true);
    headers.insert(name, header_value);
    Ok(())
}

/// Wraps a header name/value construction failure in [`RuntimeError::Other`], prefixing the
/// underlying error's display text with `invalid auth header: `.
fn header_err<E: std::fmt::Display>(e: E) -> RuntimeError {
    let message = format!("invalid auth header: {e}");
    RuntimeError::Other(message)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A bare `http` host gains port 11434 and a root path.
    #[test]
    fn normalizes_empty_path_and_port() {
        let normalized = parse_and_normalize_base_url("http://127.0.0.1").unwrap();
        let rendered: &str = normalized.as_str();
        assert_eq!(rendered, "http://127.0.0.1:11434/");
    }

    /// A sub-path without a trailing slash gets one appended.
    #[test]
    fn normalizes_subpath_trailing_slash() {
        let normalized = parse_and_normalize_base_url("http://proxy/v1/ollama").unwrap();
        let has_trailing_slash = normalized.as_str().ends_with("/v1/ollama/");
        assert!(has_trailing_slash);
    }
}