mps-rs 1.8.0

MPS — plain-text personal productivity CLI (Rust)
Documentation
pub mod context;
pub mod session;

use anyhow::{anyhow, bail, Context as _};
use futures_util::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;

/// A single chat message (role + content), compatible with the OpenAI chat format.
///
/// Serialized by serde as `{"role": ..., "content": ...}`. The constructors on
/// `Message` in this file produce the roles `"system"`, `"user"` and `"assistant"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Speaker role string, e.g. `"system"`, `"user"` or `"assistant"`.
    pub role: String,
    /// Text of the message.
    pub content: String,
}

impl Message {
    pub fn system(content: impl Into<String>) -> Self {
        Self { role: "system".into(), content: content.into() }
    }
    pub fn user(content: impl Into<String>) -> Self {
        Self { role: "user".into(), content: content.into() }
    }
    pub fn assistant(content: impl Into<String>) -> Self {
        Self { role: "assistant".into(), content: content.into() }
    }
}

/// OpenAI-compatible streaming LLM client.
///
/// Construct with [`LlmClient::new`]; requests are sent to `{base_url}/v1/...`.
pub struct LlmClient {
    /// Server root URL (as set by `new`, which strips any trailing `/`).
    pub base_url: String,
    /// Model name sent in the `model` field of each chat request.
    pub model: String,
    /// Bearer token; an empty string means no `Authorization` header is sent.
    api_key: String,
    /// HTTP client reused for every request (built with a 120 s timeout in `new`).
    client: Client,
}

impl LlmClient {
    /// Creates a client for the OpenAI-compatible server rooted at `url`.
    ///
    /// Trailing `/` characters are removed from `url` so endpoint paths can be
    /// appended directly. An empty `api_key` means no `Authorization` header
    /// is sent. The HTTP client is configured with a 120-second timeout.
    ///
    /// NOTE(review): reqwest documents `ClientBuilder::timeout` as a total
    /// request timeout, which likely also bounds how long a streamed response
    /// may keep producing tokens — confirm this is intended for long generations.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest` client cannot be built.
    pub fn new(url: &str, model: &str, api_key: &str) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .expect("reqwest client");
        Self {
            base_url: url.trim_end_matches('/').to_string(),
            model: model.to_string(),
            api_key: api_key.to_string(),
            client,
        }
    }

    /// POSTs `messages` to `{base_url}/v1/chat/completions` with `stream: true`
    /// and returns a stream of content delta strings.
    ///
    /// Each item is the `choices[0].delta.content` of one SSE `data:` event
    /// (an empty string when that event carries no content). The stream ends
    /// at the `[DONE]` sentinel or when the server closes the response body.
    ///
    /// # Errors
    ///
    /// Fails up front if the request cannot be sent or the server replies with
    /// a non-success status (the body is included in the message). Errors are
    /// yielded as stream items for transport failures while reading the body
    /// and for `data:` payloads that `parse_sse_delta` rejects.
    pub async fn chat_stream(
        &self,
        messages: &[Message],
    ) -> anyhow::Result<impl futures_util::Stream<Item = anyhow::Result<String>>> {
        #[derive(Serialize)]
        struct Req<'a> {
            model: &'a str,
            messages: &'a [Message],
            stream: bool,
        }

        let url = format!("{}/v1/chat/completions", self.base_url);
        let mut req = self
            .client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&Req { model: &self.model, messages, stream: true });

        if !self.api_key.is_empty() {
            req = req.header("Authorization", format!("Bearer {}", self.api_key));
        }

        let response = req
            .send()
            .await
            .with_context(|| format!("Cannot reach LLM at {} — is ollama/llama-server running?", self.base_url))?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            bail!("LLM returned {}: {}", status, body);
        }

        let stream = response.bytes_stream();

        // Buffer raw bytes (not text) across HTTP chunks: a chunk boundary can
        // fall inside a JSON payload *or* inside a multi-byte UTF-8 character,
        // so bytes are decoded only once a complete line is available.
        // The future is boxed so the returned stream is `Unpin`.
        let parsed = futures_util::stream::unfold(
            (stream, Vec::<u8>::new()),
            |(mut stream, mut buf)| Box::pin(async move {
                loop {
                    // Drain complete lines already buffered before reading more.
                    while let Some(line) = Self::take_line(&mut buf) {
                        // Blank separators, comments and non-data fields are skipped.
                        if let Some(data) = Self::sse_data(&line) {
                            if data == "[DONE]" {
                                // End-of-stream sentinel: stop without waiting
                                // for the server to close the connection.
                                return None;
                            }
                            let result = parse_sse_delta(data);
                            return Some((result, (stream, buf)));
                        }
                    }
                    // Need more bytes.
                    match stream.next().await {
                        None => return None,
                        Some(Err(e)) => {
                            return Some((Err(anyhow!("stream error: {}", e)), (stream, buf)));
                        }
                        Some(Ok(bytes)) => buf.extend_from_slice(&bytes),
                    }
                }
            }),
        );

        Ok(parsed)
    }

    /// Removes the first `\n`-terminated line from `buf` and returns it decoded
    /// and trimmed, or `None` if `buf` does not yet hold a complete line.
    ///
    /// Splitting on the `\n` byte is UTF-8 safe: that byte never occurs inside
    /// a multi-byte sequence, so each extracted line is a whole sequence.
    fn take_line(buf: &mut Vec<u8>) -> Option<String> {
        let nl = buf.iter().position(|&b| b == b'\n')?;
        let line = String::from_utf8_lossy(&buf[..nl]).trim().to_string();
        buf.drain(..=nl);
        Some(line)
    }

    /// Returns the non-empty payload of an SSE `data:` line, or `None` for any
    /// other line. The space after the colon is optional in the SSE format.
    fn sse_data(line: &str) -> Option<&str> {
        line.strip_prefix("data:")
            .map(str::trim_start)
            .filter(|payload| !payload.is_empty())
    }

    /// Probes local servers and returns the first base URL that answers.
    ///
    /// Tries `http://localhost:11434` (Ollama) and then `http://localhost:8080`
    /// (llama.cpp), sending `GET /v1/models` with a 2-second timeout. Any HTTP
    /// response, including an error status, counts as a hit; only connection
    /// failures and timeouts move on to the next port. Returns `None` if no
    /// port answers or the probe client cannot be built.
    pub async fn auto_detect() -> Option<String> {
        let probe_client = Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .ok()?;
        for port in [11434u16, 8080] {
            let url = format!("http://localhost:{}/v1/models", port);
            if probe_client.get(&url).send().await.is_ok() {
                return Some(format!("http://localhost:{}", port));
            }
        }
        None
    }
}

/// Extract `choices[0].delta.content` from a single SSE data line.
///
/// Yields an empty string when `choices` is empty or when the first choice's
/// `delta` has a missing or `null` `content`. Unknown fields are ignored.
///
/// # Errors
///
/// Fails when `data` is not JSON of the shape
/// `{"choices": [{"delta": {...}}, ...]}`; the message embeds the payload.
pub fn parse_sse_delta(data: &str) -> anyhow::Result<String> {
    #[derive(Deserialize)]
    struct Resp {
        choices: Vec<Choice>,
    }
    #[derive(Deserialize)]
    struct Choice {
        delta: Delta,
    }
    #[derive(Deserialize)]
    struct Delta {
        content: Option<String>,
    }

    let decoded = serde_json::from_str::<Resp>(data);
    let resp = decoded.with_context(|| format!("invalid SSE JSON: {}", data))?;
    let first_choice = resp.choices.into_iter().next();
    let text = match first_choice {
        Some(Choice { delta: Delta { content: Some(piece) } }) => piece,
        _ => String::new(),
    };
    Ok(text)
}