// agent_line/llm.rs

use crate::agent::StepError;
use std::{env, fmt, sync::Arc};

/// Reusable LLM configuration. Each agent that needs an LLM holds its own
/// `LlmConfig` and calls [`LlmConfig::request`] to start a chat request.
///
/// Build one with [`LlmConfig::builder`] for explicit settings, or with
/// [`LlmConfig::from_env`] to read from `AGENT_LINE_*` environment variables.
/// Multiple agents can share one config or each hold their own (a cheap,
/// fast model for one step, a strong reasoning model for another).
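///
/// # Examples
///
/// A minimal sketch of explicit configuration; the endpoint and model name
/// are placeholders, and the import path assumes this module is exported as
/// `agent_line::llm`:
///
/// ```
/// use agent_line::llm::{LlmConfig, Provider};
///
/// // Placeholder endpoint and model; substitute your own.
/// let llm = LlmConfig::builder()
///     .provider(Provider::Ollama)
///     .base_url("http://localhost:11434")
///     .model("llama3.1:8b")
///     .build()
///     .expect("provider, base_url, and model are all set");
/// ```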
#[derive(Clone, PartialEq, Eq)]
pub struct LlmConfig {
    base_url: String,
    model: String,
    num_ctx: u32,
    max_tokens: u32,
    api_key: Option<String>,
    provider: Provider,
}

impl fmt::Debug for LlmConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("LlmConfig")
            .field("provider", &self.provider)
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("num_ctx", &self.num_ctx)
            .field("max_tokens", &self.max_tokens)
            .field(
                "api_key",
                &if self.api_key.is_some() {
                    "set"
                } else {
                    "not set"
                },
            )
            .finish()
    }
}

/// Error returned when building an [`LlmConfig`] without required fields.
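///
/// # Examples
///
/// A small sketch of the failure mode; the import path assumes this module
/// is exported as `agent_line::llm`:
///
/// ```
/// use agent_line::llm::{LlmConfig, LlmConfigError};
///
/// // `provider` is validated first, so an empty builder reports that field.
/// let err = LlmConfig::builder().build().unwrap_err();
/// assert_eq!(err, LlmConfigError::MissingProvider);
/// ```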
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LlmConfigError {
    /// No provider was configured.
    MissingProvider,
    /// No base URL was configured.
    MissingBaseUrl,
    /// No model name was configured.
    MissingModel,
}

impl fmt::Display for LlmConfigError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MissingProvider => write!(f, "LlmConfig missing provider"),
            Self::MissingBaseUrl => write!(f, "LlmConfig missing base_url"),
            Self::MissingModel => write!(f, "LlmConfig missing model"),
        }
    }
}

impl std::error::Error for LlmConfigError {}

/// Builder for [`LlmConfig`].
#[derive(Default)]
pub struct LlmConfigBuilder {
    provider: Option<Provider>,
    base_url: Option<String>,
    model: Option<String>,
    api_key: Option<String>,
    num_ctx: Option<u32>,
    max_tokens: Option<u32>,
}

/// Builder for LLM chat requests. Obtained via [`LlmConfig::request`].
pub struct LlmRequestBuilder {
    config: Arc<LlmConfig>,
    system: Option<String>,
    messages: Vec<String>,
}

/// LLM provider. Selected via [`LlmConfigBuilder::provider`] or the
/// `AGENT_LINE_PROVIDER` env var when using [`LlmConfig::from_env`].
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Provider {
    /// Ollama (default). Local inference, no API key needed.
    Ollama,
    /// OpenAI-compatible APIs (OpenRouter, etc.).
    OpenAi,
    /// Anthropic API.
    Anthropic,
}

impl Provider {
    /// Parse a provider name. Unrecognized values default to Ollama.
    pub(crate) fn from_str(s: &str) -> Self {
        match s.to_lowercase().as_str() {
            "openai" => Provider::OpenAi,
            "anthropic" => Provider::Anthropic,
            _ => Provider::Ollama,
        }
    }

    pub(crate) fn endpoint(&self, base_url: &str) -> String {
        let base = base_url.trim_end_matches('/');
        match self {
            Provider::Ollama => format!("{base}/api/chat"),
            Provider::OpenAi => format!("{base}/v1/chat/completions"),
            Provider::Anthropic => format!("{base}/v1/messages"),
        }
    }

    pub(crate) fn parse_response(&self, json: &serde_json::Value) -> Result<String, StepError> {
        let content = match self {
            Provider::Ollama => json["message"]["content"].as_str(),
            Provider::OpenAi => json["choices"][0]["message"]["content"].as_str(),
            Provider::Anthropic => json["content"][0]["text"].as_str(),
        };
        content
            .map(|s| s.to_string())
            .ok_or_else(|| StepError::other("llm response missing message content"))
    }
}

impl LlmConfig {
    /// Start building an explicit LLM configuration.
    pub fn builder() -> LlmConfigBuilder {
        LlmConfigBuilder::default()
    }

    /// Build an LLM configuration from `AGENT_LINE_*` environment variables.
    ///
    /// Reads `AGENT_LINE_PROVIDER`, `AGENT_LINE_LLM_URL`, `AGENT_LINE_MODEL`,
    /// `AGENT_LINE_API_KEY`, `AGENT_LINE_NUM_CTX` (Ollama context window),
    /// and `AGENT_LINE_MAX_TOKENS` (OpenAI/Anthropic response cap; falls back
    /// to `AGENT_LINE_NUM_CTX` if unset). Defaults to a local Ollama
    /// configuration when nothing is set.
    ///
    /// If `AGENT_LINE_DEBUG` is set, the resolved config is logged to stderr
    /// once.
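    ///
    /// # Examples
    ///
    /// A minimal sketch; with no `AGENT_LINE_*` variables set, this resolves
    /// to local Ollama at `http://localhost:11434` with model `llama3.1:8b`
    /// (the import path assumes this module is exported as `agent_line::llm`):
    ///
    /// ```no_run
    /// use agent_line::llm::LlmConfig;
    ///
    /// let llm = LlmConfig::from_env();
    /// ```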
    pub fn from_env() -> Self {
        let num_ctx = match env::var("AGENT_LINE_NUM_CTX") {
            Ok(v) => v.parse::<u32>().unwrap_or(4096),
            Err(_) => 4096,
        };
        let max_tokens = match env::var("AGENT_LINE_MAX_TOKENS") {
            Ok(v) => v.parse::<u32>().unwrap_or(num_ctx),
            Err(_) => num_ctx,
        };

        let config = Self {
            provider: Provider::from_str(
                &env::var("AGENT_LINE_PROVIDER").unwrap_or_else(|_| "ollama".to_string()),
            ),
            base_url: env::var("AGENT_LINE_LLM_URL")
                .unwrap_or_else(|_| "http://localhost:11434".to_string()),
            model: env::var("AGENT_LINE_MODEL").unwrap_or_else(|_| "llama3.1:8b".to_string()),
            api_key: env::var("AGENT_LINE_API_KEY").ok(),
            num_ctx,
            max_tokens,
        };
        config.debug_log();
        config
    }

    /// Consume this config and return it with a different model name. All
    /// other fields (provider, base URL, API key, token budgets) are preserved.
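    ///
    /// # Examples
    ///
    /// A sketch of holding a cheap and a strong model on the same endpoint;
    /// the model names are placeholders, and the import path assumes this
    /// module is exported as `agent_line::llm`:
    ///
    /// ```
    /// use agent_line::llm::LlmConfig;
    ///
    /// let fast = LlmConfig::from_env();
    /// // `with_model` consumes its receiver, so clone first to keep both.
    /// let strong = fast.clone().with_model("llama3.1:70b");
    /// ```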
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Start building an LLM chat request that uses this config.
    ///
    /// Each call creates a fresh [`LlmRequestBuilder`]; chain `.system()`,
    /// `.user()`, and `.send()` on the result. The config itself is not
    /// consumed, so an agent can call `self.llm.request()` repeatedly.
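    ///
    /// # Examples
    ///
    /// A minimal sketch; `send` makes a blocking HTTP call, so this is not
    /// run as a doctest (the import path assumes this module is exported as
    /// `agent_line::llm`):
    ///
    /// ```no_run
    /// use agent_line::llm::LlmConfig;
    ///
    /// let llm = LlmConfig::from_env();
    /// match llm.request().system("Be terse.").user("Say hello.").send() {
    ///     Ok(reply) => println!("{reply}"),
    ///     Err(_) => eprintln!("llm request failed"),
    /// }
    /// ```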
    pub fn request(&self) -> LlmRequestBuilder {
        LlmRequestBuilder {
            config: Arc::new(self.clone()),
            system: None,
            messages: Vec::new(),
        }
    }

    fn debug_log(&self) {
        if env::var("AGENT_LINE_DEBUG").is_ok() {
            eprintln!(
                "[debug] provider: {:?}\n\
                 [debug] model: {}\n\
                 [debug] base_url: {}\n\
                 [debug] num_ctx: {}\n\
                 [debug] max_tokens: {}\n\
                 [debug] api_key: {}",
                self.provider,
                self.model,
                self.base_url,
                self.num_ctx,
                self.max_tokens,
                if self.api_key.is_some() {
                    "set"
                } else {
                    "not set"
                },
            );
        }
    }
}

impl LlmConfigBuilder {
    /// Set the LLM provider. Required.
    pub fn provider(mut self, provider: Provider) -> Self {
        self.provider = Some(provider);
        self
    }

    /// Set the base URL of the LLM endpoint. Required.
    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Set the model name. Required.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Set the API key. Optional for local providers.
    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Set the context window size sent in the `options.num_ctx` field of
    /// Ollama requests. Defaults to 4096. Ignored by OpenAI-compatible and
    /// Anthropic providers; use [`max_tokens`](Self::max_tokens) for those.
    pub fn num_ctx(mut self, num_ctx: u32) -> Self {
        self.num_ctx = Some(num_ctx);
        self
    }

    /// Set the maximum number of generated tokens sent in the `max_tokens`
    /// field of OpenAI-compatible and Anthropic requests. Defaults to 4096.
    /// Ignored by Ollama; use [`num_ctx`](Self::num_ctx) for that.
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Build the [`LlmConfig`].
    pub fn build(self) -> Result<LlmConfig, LlmConfigError> {
        Ok(LlmConfig {
            provider: self.provider.ok_or(LlmConfigError::MissingProvider)?,
            base_url: self.base_url.ok_or(LlmConfigError::MissingBaseUrl)?,
            model: self.model.ok_or(LlmConfigError::MissingModel)?,
            api_key: self.api_key,
            num_ctx: self.num_ctx.unwrap_or(4096),
            max_tokens: self.max_tokens.unwrap_or(4096),
        })
    }
}

impl LlmRequestBuilder {
    /// Set the system prompt.
    pub fn system(mut self, msg: impl Into<String>) -> Self {
        self.system = Some(msg.into());
        self
    }

    /// Append a user message.
    pub fn user(mut self, msg: impl Into<String>) -> Self {
        self.messages.push(msg.into());
        self
    }

    /// Send the request and return the assistant's response text.
    pub fn send(self) -> Result<String, StepError> {
        let mut messages = Vec::new();

        if let Some(sys) = &self.system {
            messages.push(serde_json::json!({
                "role": "system",
                "content": sys
            }));
        }

        for msg in &self.messages {
            messages.push(serde_json::json!({
                "role": "user",
                "content": msg
            }));
        }

        let body = match &self.config.provider {
            Provider::Ollama => serde_json::json!({
                "model": self.config.model,
                "messages": messages,
                "stream": false,
                // Disable Qwen 3-style "thinking" tokens. Thinking models can
                // otherwise spend minutes generating <think>...</think>
                // reasoning before producing the actual response, which is
                // rarely what an agentic workflow wants. Ignored by models
                // that do not support thinking.
                "think": false,
                "options": {
                    "num_ctx": self.config.num_ctx
                }
            }),
            Provider::OpenAi => serde_json::json!({
                "model": self.config.model,
                "messages": messages,
                "stream": false,
                "max_tokens": self.config.max_tokens
            }),
            Provider::Anthropic => serde_json::json!({
                "model": self.config.model,
                "messages": messages,
                "stream": false,
                "max_tokens": self.config.max_tokens
            }),
        };

        let url = self.config.provider.endpoint(&self.config.base_url);
        let mut request = ureq::post(&url);

        match &self.config.provider {
            Provider::Anthropic => {
                if let Some(key) = &self.config.api_key {
                    request = request.header("x-api-key", key);
                }
                request = request.header("anthropic-version", "2023-06-01");
                request = request.header("content-type", "application/json");
            }
            _ => {
                if let Some(key) = &self.config.api_key {
                    request = request.header("Authorization", &format!("Bearer {key}"));
                }
            }
        }

        if std::env::var("AGENT_LINE_DEBUG").is_ok() {
            eprintln!("[debug] LLM request to {url}");
            eprintln!(
                "[debug] Messages: {}",
                serde_json::to_string_pretty(&messages).unwrap_or_default()
            );
        }

        let mut response = request
            .send_json(&body)
            .map_err(|e| StepError::transient(format!("llm request failed: {e}")))?;

        let json: serde_json::Value = response
            .body_mut()
            .read_json()
            .map_err(|e| StepError::transient(format!("llm response parse failed: {e}")))?;

        if std::env::var("AGENT_LINE_DEBUG").is_ok() {
            eprintln!("[debug] LLM response: {}", &json);
        }

        self.config.provider.parse_response(&json)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // --- Provider::from_str ---

    #[test]
    fn test_provider_from_str_ollama() {
        assert_eq!(Provider::from_str("ollama"), Provider::Ollama);
    }

    #[test]
    fn test_provider_from_str_openai() {
        assert_eq!(Provider::from_str("openai"), Provider::OpenAi);
    }

    #[test]
    fn test_provider_from_str_anthropic() {
        assert_eq!(Provider::from_str("anthropic"), Provider::Anthropic);
    }

    #[test]
    fn test_provider_from_str_case_insensitive() {
        assert_eq!(Provider::from_str("OpenAI"), Provider::OpenAi);
        assert_eq!(Provider::from_str("ANTHROPIC"), Provider::Anthropic);
        assert_eq!(Provider::from_str("Ollama"), Provider::Ollama);
    }

    #[test]
    fn test_provider_from_str_unknown_defaults_to_ollama() {
        assert_eq!(Provider::from_str("something"), Provider::Ollama);
    }

    // --- Provider::endpoint ---

    #[test]
    fn test_ollama_endpoint() {
        assert_eq!(
            Provider::Ollama.endpoint("http://localhost:11434"),
            "http://localhost:11434/api/chat"
        );
    }

    #[test]
    fn test_openai_endpoint() {
        assert_eq!(
            Provider::OpenAi.endpoint("https://openrouter.ai"),
            "https://openrouter.ai/v1/chat/completions"
        );
    }

    #[test]
    fn test_anthropic_endpoint() {
        assert_eq!(
            Provider::Anthropic.endpoint("https://api.anthropic.com"),
            "https://api.anthropic.com/v1/messages"
        );
    }

    #[test]
    fn test_endpoint_strips_trailing_slash() {
        assert_eq!(
            Provider::OpenAi.endpoint("https://openrouter.ai/"),
            "https://openrouter.ai/v1/chat/completions"
        );
    }

    // --- Provider::parse_response ---

    #[test]
    fn test_ollama_parse_response() {
        let json = serde_json::json!({
            "message": { "content": "Hello from Ollama" }
        });
        assert_eq!(
            Provider::Ollama.parse_response(&json).unwrap(),
            "Hello from Ollama"
        );
    }

    #[test]
    fn test_openai_parse_response() {
        let json = serde_json::json!({
            "choices": [{ "message": { "content": "Hello from OpenRouter" } }]
        });
        assert_eq!(
            Provider::OpenAi.parse_response(&json).unwrap(),
            "Hello from OpenRouter"
        );
    }

    #[test]
    fn test_anthropic_parse_response() {
        let json = serde_json::json!({
            "content": [{ "text": "Hello from Claude" }]
        });
        assert_eq!(
            Provider::Anthropic.parse_response(&json).unwrap(),
            "Hello from Claude"
        );
    }

    #[test]
    fn test_parse_response_missing_content_is_error() {
        let json = serde_json::json!({"unexpected": "shape"});
        assert!(Provider::Ollama.parse_response(&json).is_err());
        assert!(Provider::OpenAi.parse_response(&json).is_err());
        assert!(Provider::Anthropic.parse_response(&json).is_err());
    }

    // --- LlmConfig builder ---

    #[test]
    fn llm_config_builder_happy_path() {
        let config = LlmConfig::builder()
            .provider(Provider::OpenAi)
            .base_url("https://example.com")
            .model("gpt-4")
            .api_key("key")
            .num_ctx(8192)
            .max_tokens(2048)
            .build()
            .unwrap();

        assert_eq!(config.provider, Provider::OpenAi);
        assert_eq!(config.base_url, "https://example.com");
        assert_eq!(config.model, "gpt-4");
        assert_eq!(config.api_key.as_deref(), Some("key"));
        assert_eq!(config.num_ctx, 8192);
        assert_eq!(config.max_tokens, 2048);
    }

    #[test]
    fn llm_config_builder_defaults_token_fields_to_4096() {
        let config = LlmConfig::builder()
            .provider(Provider::Ollama)
            .base_url("http://localhost:11434")
            .model("llama3")
            .build()
            .unwrap();

        assert_eq!(config.num_ctx, 4096);
        assert_eq!(config.max_tokens, 4096);
    }

    #[test]
    fn llm_config_builder_api_key_optional() {
        let config = LlmConfig::builder()
            .provider(Provider::Ollama)
            .base_url("http://localhost:11434")
            .model("llama3")
            .build()
            .unwrap();

        assert!(config.api_key.is_none());
    }

    #[test]
    fn llm_config_builder_errors_without_provider() {
        let err = LlmConfig::builder()
            .base_url("http://localhost:11434")
            .model("llama3")
            .build()
            .unwrap_err();

        assert_eq!(err, LlmConfigError::MissingProvider);
    }

    #[test]
    fn llm_config_builder_errors_without_base_url() {
        let err = LlmConfig::builder()
            .provider(Provider::Ollama)
            .model("llama3")
            .build()
            .unwrap_err();

        assert_eq!(err, LlmConfigError::MissingBaseUrl);
    }

    #[test]
    fn llm_config_builder_errors_without_model() {
        let err = LlmConfig::builder()
            .provider(Provider::Ollama)
            .base_url("http://localhost:11434")
            .build()
            .unwrap_err();

        assert_eq!(err, LlmConfigError::MissingModel);
    }

    #[test]
    fn request_uses_owned_config() {
        let cfg = LlmConfig::builder()
            .provider(Provider::Ollama)
            .base_url("http://localhost:11434")
            .model("llama3")
            .build()
            .unwrap();

        let req = cfg.request().system("hi").user("hello");

        assert_eq!(req.config.model, "llama3");
        assert_eq!(req.config.provider, Provider::Ollama);
        assert_eq!(req.config.base_url, "http://localhost:11434");
    }

    #[test]
    fn request_can_be_called_repeatedly_on_same_config() {
        let cfg = LlmConfig::builder()
            .provider(Provider::Ollama)
            .base_url("http://localhost:11434")
            .model("llama3")
            .build()
            .unwrap();

        // Each request() returns its own builder; the config is not consumed.
        let r1 = cfg.request().user("first");
        let r2 = cfg.request().user("second");

        assert_eq!(r1.messages, vec!["first".to_string()]);
        assert_eq!(r2.messages, vec!["second".to_string()]);
        assert_eq!(r1.config.model, "llama3");
        assert_eq!(r2.config.model, "llama3");
    }
}