// agent_line/ctx.rs
1use crate::agent::StepError;
2use std::{collections::HashMap, env, sync::Arc};
3
/// Execution context for agents (services, config, etc.)
pub struct Ctx {
    // String key/value store shared across agent steps.
    store: HashMap<String, String>,
    // Append-only event log; read via `logs()`, reset via `clear_logs()`.
    log: Vec<String>,
    // Shared LLM configuration; handed to each request builder via `Arc`,
    // so builders stay usable independently of this `Ctx`.
    llm_client: Arc<LlmClient>,
}
10
// Immutable LLM connection settings, built once in `Ctx::new` from
// environment variables and shared with request builders via `Arc`.
struct LlmClient {
    // Provider base URL without the chat path; see `Provider::endpoint`.
    base_url: String,
    // Model identifier sent verbatim in the request body.
    model: String,
    // Context-window size: sent as `options.num_ctx` for Ollama and as
    // `max_tokens` for the other providers.
    num_ctx: u32,
    // Optional API key; sent as `Authorization: Bearer` or `x-api-key`.
    api_key: Option<String>,
    // Wire protocol to speak; parsed from `AGENT_LINE_PROVIDER`.
    provider: Provider,
}
18
/// Builder for LLM chat requests. Obtained via [`Ctx::llm`].
pub struct LlmRequestBuilder {
    // Shared connection settings, cloned (refcount bump) from the owning `Ctx`.
    client: Arc<LlmClient>,
    // Optional system prompt set via [`LlmRequestBuilder::system`].
    system: Option<String>,
    // User messages, sent in insertion order.
    messages: Vec<String>,
}
25
/// LLM provider, set via the `AGENT_LINE_PROVIDER` env var.
//
// Fieldless three-variant enum: derive `Clone`/`Copy`/`Eq` alongside the
// existing `Debug`/`PartialEq` so callers can copy it instead of borrowing
// (Rust API guidelines: derive eagerly). Backward compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Provider {
    /// Ollama (default). Local inference, no API key needed.
    Ollama,
    /// OpenAI-compatible APIs (OpenRouter, etc.).
    OpenAi,
    /// Anthropic API.
    Anthropic,
}
36
37impl Provider {
38    /// Parse a provider name. Unrecognized values default to Ollama.
39    pub fn from_str(s: &str) -> Self {
40        match s.to_lowercase().as_str() {
41            "openai" => Provider::OpenAi,
42            "anthropic" => Provider::Anthropic,
43            _ => Provider::Ollama,
44        }
45    }
46
47    /// Return the full chat endpoint URL for this provider.
48    pub fn endpoint(&self, base_url: &str) -> String {
49        let base = base_url.trim_end_matches('/');
50        match self {
51            Provider::Ollama => format!("{base}/api/chat"),
52            Provider::OpenAi => format!("{base}/v1/chat/completions"),
53            Provider::Anthropic => format!("{base}/v1/messages"),
54        }
55    }
56
57    /// Extract the assistant message from a provider-specific JSON response.
58    pub fn parse_response(&self, json: &serde_json::Value) -> Result<String, StepError> {
59        let content = match self {
60            Provider::Ollama => json["message"]["content"].as_str(),
61            Provider::OpenAi => json["choices"][0]["message"]["content"].as_str(),
62            Provider::Anthropic => json["content"][0]["text"].as_str(),
63        };
64        content
65            .map(|s| s.to_string())
66            .ok_or_else(|| StepError::other("llm response missing message content"))
67    }
68}
69
70impl LlmRequestBuilder {
71    /// Set the system prompt.
72    pub fn system(mut self, msg: &str) -> Self {
73        self.system = Some(msg.to_string());
74        self
75    }
76
77    /// Append a user message.
78    pub fn user(mut self, msg: impl Into<String>) -> Self {
79        self.messages.push(msg.into());
80        self
81    }
82
83    /// Send the request and return the assistant's response text.
84    pub fn send(self) -> Result<String, StepError> {
85        let mut messages = Vec::new();
86
87        if let Some(sys) = &self.system {
88            messages.push(serde_json::json!({
89                "role": "system",
90                "content": sys
91            }));
92        }
93
94        for msg in &self.messages {
95            messages.push(serde_json::json!({
96                "role": "user",
97                "content": msg
98            }));
99        }
100
101        let body = match &self.client.provider {
102            Provider::Ollama => serde_json::json!({
103                "model": self.client.model,
104                "messages": messages,
105                "stream": false,
106                "options": {
107                    "num_ctx": self.client.num_ctx
108                }
109            }),
110            Provider::OpenAi => serde_json::json!({
111                "model": self.client.model,
112                "messages": messages,
113                "stream": false,
114                "max_tokens": self.client.num_ctx
115            }),
116            Provider::Anthropic => serde_json::json!({
117                "model": self.client.model,
118                "messages": messages,
119                "stream": false,
120                "max_tokens": self.client.num_ctx
121            }),
122        };
123
124        let url = self.client.provider.endpoint(&self.client.base_url);
125        let mut request = ureq::post(&url);
126
127        match &self.client.provider {
128            Provider::Anthropic => {
129                if let Some(key) = &self.client.api_key {
130                    request = request.header("x-api-key", key);
131                }
132                request = request.header("anthropic-version", "2023-06-01");
133                request = request.header("content-type", "application/json");
134            }
135            _ => {
136                if let Some(key) = &self.client.api_key {
137                    request = request.header("Authorization", &format!("Bearer {key}"));
138                }
139            }
140        }
141
142        if std::env::var("AGENT_LINE_DEBUG").is_ok() {
143            eprintln!("[debug] LLM request to {}", url);
144            eprintln!(
145                "[debug] Messages: {}",
146                serde_json::to_string_pretty(&messages).unwrap_or_default()
147            );
148        }
149
150        let mut response = request
151            .send_json(&body)
152            .map_err(|e| StepError::transient(format!("llm request failed: {e}")))?;
153
154        let json: serde_json::Value = response
155            .body_mut()
156            .read_json()
157            .map_err(|e| StepError::transient(format!("llm response parse failed: {e}")))?;
158
159        if std::env::var("AGENT_LINE_DEBUG").is_ok() {
160            eprintln!("[debug] LLM response: {}", &json);
161        }
162
163        self.client.provider.parse_response(&json)
164    }
165}
166
167impl Ctx {
168    /// Create a new context. Configuration is read from environment variables
169    /// (see the crate-level docs for the full list).
170    pub fn new() -> Self {
171        let model = env::var("AGENT_LINE_MODEL").unwrap_or_else(|_| "llama3.1:8b".to_string());
172        let base_url =
173            env::var("AGENT_LINE_LLM_URL").unwrap_or_else(|_| "http://localhost:11434".to_string());
174
175        let num_ctx = match env::var("AGENT_LINE_NUM_CTX") {
176            Ok(v) => v.parse::<u32>().unwrap_or(4096),
177            Err(_) => 4096,
178        };
179
180        let api_key = env::var("AGENT_LINE_API_KEY").ok();
181        let provider = Provider::from_str(
182            &env::var("AGENT_LINE_PROVIDER").unwrap_or_else(|_| "ollama".to_string()),
183        );
184
185        if env::var("AGENT_LINE_DEBUG").is_ok() {
186            eprintln!(
187                "[debug] provider: {:?}\n\
188                 [debug] model: {}\n\
189                 [debug] base_url: {}\n\
190                 [debug] num_ctx: {}\n\
191                 [debug] api_key: {}",
192                provider,
193                model,
194                base_url,
195                num_ctx,
196                if api_key.is_some() { "set" } else { "not set" },
197            );
198        }
199
200        Self {
201            store: HashMap::new(),
202            log: vec![],
203            llm_client: Arc::new(LlmClient {
204                base_url,
205                model,
206                num_ctx,
207                api_key,
208                provider,
209            }),
210        }
211    }
212
213    /// Insert or overwrite a key in the KV store.
214    pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) {
215        self.store.insert(key.into(), value.into());
216    }
217
218    /// Look up a key in the KV store.
219    pub fn get(&self, key: &str) -> Option<&str> {
220        self.store.get(key).map(|s| s.as_str())
221    }
222
223    /// Remove a key from the KV store, returning its value if it existed.
224    pub fn remove(&mut self, key: &str) -> Option<String> {
225        self.store.remove(key)
226    }
227
228    /// Append a message to the event log.
229    pub fn log(&mut self, msg: impl Into<String>) {
230        self.log.push(msg.into());
231    }
232
233    /// Return all log messages in order.
234    pub fn logs(&self) -> &[String] {
235        &self.log
236    }
237
238    /// Clear the event log, leaving the KV store intact.
239    pub fn clear_logs(&mut self) {
240        self.log.clear();
241    }
242
243    /// Clear both the KV store and the event log.
244    pub fn clear(&mut self) {
245        self.store.clear();
246        self.log.clear();
247    }
248
249    /// Start building an LLM chat request.
250    pub fn llm(&self) -> LlmRequestBuilder {
251        LlmRequestBuilder {
252            client: Arc::clone(&self.llm_client),
253            system: None,
254            messages: Vec::new(),
255        }
256    }
257}
258
259impl Default for Ctx {
260    fn default() -> Self {
261        Self::new()
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand: assert that `provider` maps `base` to `expected`.
    fn assert_endpoint(provider: Provider, base: &str, expected: &str) {
        assert_eq!(provider.endpoint(base), expected);
    }

    // --- Provider::from_str ---

    #[test]
    fn test_provider_from_str_ollama() {
        assert!(matches!(Provider::from_str("ollama"), Provider::Ollama));
    }

    #[test]
    fn test_provider_from_str_openai() {
        assert!(matches!(Provider::from_str("openai"), Provider::OpenAi));
    }

    #[test]
    fn test_provider_from_str_anthropic() {
        assert!(matches!(Provider::from_str("anthropic"), Provider::Anthropic));
    }

    #[test]
    fn test_provider_from_str_case_insensitive() {
        let cases = [
            ("OpenAI", Provider::OpenAi),
            ("ANTHROPIC", Provider::Anthropic),
            ("Ollama", Provider::Ollama),
        ];
        for (input, expected) in cases {
            assert_eq!(Provider::from_str(input), expected);
        }
    }

    #[test]
    fn test_provider_from_str_unknown_defaults_to_ollama() {
        assert!(matches!(Provider::from_str("something"), Provider::Ollama));
    }

    // --- Provider::endpoint ---

    #[test]
    fn test_ollama_endpoint() {
        assert_endpoint(
            Provider::Ollama,
            "http://localhost:11434",
            "http://localhost:11434/api/chat",
        );
    }

    #[test]
    fn test_openai_endpoint() {
        assert_endpoint(
            Provider::OpenAi,
            "https://openrouter.ai",
            "https://openrouter.ai/v1/chat/completions",
        );
    }

    #[test]
    fn test_anthropic_endpoint() {
        assert_endpoint(
            Provider::Anthropic,
            "https://api.anthropic.com",
            "https://api.anthropic.com/v1/messages",
        );
    }

    #[test]
    fn test_endpoint_strips_trailing_slash() {
        assert_endpoint(
            Provider::OpenAi,
            "https://openrouter.ai/",
            "https://openrouter.ai/v1/chat/completions",
        );
    }

    // --- Provider::parse_response ---

    #[test]
    fn test_ollama_parse_response() {
        let body = serde_json::json!({
            "message": { "content": "Hello from Ollama" }
        });
        let text = Provider::Ollama.parse_response(&body).unwrap();
        assert_eq!(text, "Hello from Ollama");
    }

    #[test]
    fn test_openai_parse_response() {
        let body = serde_json::json!({
            "choices": [{ "message": { "content": "Hello from OpenRouter" } }]
        });
        let text = Provider::OpenAi.parse_response(&body).unwrap();
        assert_eq!(text, "Hello from OpenRouter");
    }

    #[test]
    fn test_anthropic_parse_response() {
        let body = serde_json::json!({
            "content": [{ "text": "Hello from Claude" }]
        });
        let text = Provider::Anthropic.parse_response(&body).unwrap();
        assert_eq!(text, "Hello from Claude");
    }

    #[test]
    fn test_parse_response_missing_content_is_error() {
        let bad = serde_json::json!({ "unexpected": "shape" });
        for provider in [Provider::Ollama, Provider::OpenAi, Provider::Anthropic] {
            assert!(provider.parse_response(&bad).is_err());
        }
    }

    // --- KV store ---

    #[test]
    fn set_then_get() {
        let mut ctx = Ctx::new();
        ctx.set("key", "value");
        assert_eq!(ctx.get("key"), Some("value"));
    }

    #[test]
    fn get_missing_key() {
        assert_eq!(Ctx::new().get("nope"), None);
    }

    #[test]
    fn set_overwrites() {
        let mut ctx = Ctx::new();
        for value in ["first", "second"] {
            ctx.set("key", value);
        }
        assert_eq!(ctx.get("key"), Some("second"));
    }

    #[test]
    fn remove_returns_value() {
        let mut ctx = Ctx::new();
        ctx.set("key", "value");
        assert_eq!(ctx.remove("key"), Some(String::from("value")));
        assert!(ctx.get("key").is_none());
    }

    #[test]
    fn remove_missing_key() {
        assert!(Ctx::new().remove("nope").is_none());
    }

    // --- Logging ---

    #[test]
    fn log_appends_and_logs_returns_in_order() {
        let mut ctx = Ctx::new();
        for msg in ["first", "second", "third"] {
            ctx.log(msg);
        }
        assert_eq!(ctx.logs(), &["first", "second", "third"]);
    }

    #[test]
    fn clear_logs_preserves_store() {
        let mut ctx = Ctx::new();
        ctx.set("key", "value");
        ctx.log("msg");
        ctx.clear_logs();
        assert!(ctx.logs().is_empty());
        assert_eq!(ctx.get("key"), Some("value"));
    }

    #[test]
    fn clear_empties_both() {
        let mut ctx = Ctx::new();
        ctx.set("key", "value");
        ctx.log("msg");
        ctx.clear();
        assert!(ctx.logs().is_empty());
        assert!(ctx.get("key").is_none());
    }
}