Skip to main content

mps/llm/
mod.rs

1pub mod context;
2pub mod session;
3
4use anyhow::{anyhow, bail, Context as _};
5use futures_util::StreamExt;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9
10/// A single chat message (role + content), compatible with the OpenAI chat format.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Message {
13    pub role: String,
14    pub content: String,
15}
16
17impl Message {
18    pub fn system(content: impl Into<String>) -> Self {
19        Self {
20            role: "system".into(),
21            content: content.into(),
22        }
23    }
24    pub fn user(content: impl Into<String>) -> Self {
25        Self {
26            role: "user".into(),
27            content: content.into(),
28        }
29    }
30    pub fn assistant(content: impl Into<String>) -> Self {
31        Self {
32            role: "assistant".into(),
33            content: content.into(),
34        }
35    }
36}
37
38/// OpenAI-compatible streaming LLM client.
39pub struct LlmClient {
40    pub base_url: String,
41    pub model: String,
42    api_key: String,
43    client: Client,
44}
45
46impl LlmClient {
47    pub fn new(url: &str, model: &str, api_key: &str, connect_timeout_secs: u64) -> Self {
48        let client = Client::builder()
49            .connect_timeout(Duration::from_secs(connect_timeout_secs))
50            .build()
51            .expect("reqwest client");
52        Self {
53            base_url: url.trim_end_matches('/').to_string(),
54            model: model.to_string(),
55            api_key: api_key.to_string(),
56            client,
57        }
58    }
59
60    /// POST to /v1/chat/completions with stream:true, yield content delta strings.
61    pub async fn chat_stream(
62        &self,
63        messages: &[Message],
64    ) -> anyhow::Result<impl futures_util::Stream<Item = anyhow::Result<String>>> {
65        #[derive(Serialize)]
66        struct Req<'a> {
67            model: &'a str,
68            messages: &'a [Message],
69            stream: bool,
70        }
71
72        let url = format!("{}/v1/chat/completions", self.base_url);
73        let mut req = self
74            .client
75            .post(&url)
76            .header("Content-Type", "application/json")
77            .json(&Req {
78                model: &self.model,
79                messages,
80                stream: true,
81            });
82
83        if !self.api_key.is_empty() {
84            req = req.header("Authorization", format!("Bearer {}", self.api_key));
85        }
86
87        let response = req.send().await.with_context(|| {
88            format!(
89                "Cannot reach LLM at {} — is ollama/llama-server running?",
90                self.base_url
91            )
92        })?;
93
94        if !response.status().is_success() {
95            let status = response.status();
96            let body = response.text().await.unwrap_or_default();
97            bail!("LLM returned {}: {}", status, body);
98        }
99
100        let stream = response.bytes_stream();
101
102        // Buffer partial lines across HTTP chunks so a JSON payload split over two
103        // chunks is reassembled before being parsed.
104        let parsed =
105            futures_util::stream::unfold((stream, String::new()), |(mut stream, mut buf)| {
106                Box::pin(async move {
107                    loop {
108                        // Drain any complete lines already in the buffer.
109                        if let Some(nl) = buf.find('\n') {
110                            let line = buf[..nl].trim().to_string();
111                            buf = buf[nl + 1..].to_string();
112                            if let Some(data) = line.strip_prefix("data: ") {
113                                if data == "[DONE]" {
114                                    continue;
115                                }
116                                let result = parse_sse_delta(data);
117                                return Some((result, (stream, buf)));
118                            }
119                            continue;
120                        }
121                        // Need more bytes.
122                        match stream.next().await {
123                            None => return None,
124                            Some(Err(e)) => {
125                                return Some((Err(anyhow!("stream error: {}", e)), (stream, buf)));
126                            }
127                            Some(Ok(bytes)) => {
128                                buf.push_str(&String::from_utf8_lossy(&bytes));
129                            }
130                        }
131                    }
132                })
133            });
134
135        Ok(parsed)
136    }
137
138    /// Probe :11434 (Ollama) then :8080 (llama.cpp). Returns first base URL that responds.
139    pub async fn auto_detect() -> Option<String> {
140        let probe_client = Client::builder()
141            .timeout(Duration::from_secs(2))
142            .build()
143            .ok()?;
144        for port in [11434u16, 8080] {
145            let url = format!("http://localhost:{}/v1/models", port);
146            if probe_client.get(&url).send().await.is_ok() {
147                return Some(format!("http://localhost:{}", port));
148            }
149        }
150        None
151    }
152}
153
154/// Extract `choices[0].delta.content` from a single SSE data line.
155pub fn parse_sse_delta(data: &str) -> anyhow::Result<String> {
156    #[derive(Deserialize)]
157    struct Delta {
158        content: Option<String>,
159    }
160    #[derive(Deserialize)]
161    struct Choice {
162        delta: Delta,
163    }
164    #[derive(Deserialize)]
165    struct Resp {
166        choices: Vec<Choice>,
167    }
168
169    let resp: Resp =
170        serde_json::from_str(data).with_context(|| format!("invalid SSE JSON: {}", data))?;
171    Ok(resp
172        .choices
173        .into_iter()
174        .next()
175        .and_then(|c| c.delta.content)
176        .unwrap_or_default())
177}