use crate::agent::StepError;
use std::{env, fmt, sync::Arc};

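/// Configuration for talking to an LLM chat API.
///
/// Holds the provider, base URL, model name, context/token limits, and an
/// optional API key. Build one with [`LlmConfig::builder`] or read it from
/// the environment with [`LlmConfig::from_env`], then create requests via
/// [`LlmConfig::request`].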
#[derive(Clone, PartialEq, Eq)]
pub struct LlmConfig {
    base_url: String,
    model: String,
    num_ctx: u32,
    max_tokens: u32,
    api_key: Option<String>,
    provider: Provider,
}

// Manual `Debug` impl so the API key itself is never printed, only whether
// it is set.
impl fmt::Debug for LlmConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("LlmConfig")
            .field("provider", &self.provider)
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("num_ctx", &self.num_ctx)
            .field("max_tokens", &self.max_tokens)
            .field(
                "api_key",
                &if self.api_key.is_some() {
                    "set"
                } else {
                    "not set"
                },
            )
            .finish()
    }
}

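/// Error returned by [`LlmConfigBuilder::build`] when a required field was
/// never set.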
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LlmConfigError {
    /// No provider was supplied.
    MissingProvider,
    /// No base URL was supplied.
    MissingBaseUrl,
    /// No model name was supplied.
    MissingModel,
}

impl fmt::Display for LlmConfigError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MissingProvider => write!(f, "LlmConfig missing provider"),
            Self::MissingBaseUrl => write!(f, "LlmConfig missing base_url"),
            Self::MissingModel => write!(f, "LlmConfig missing model"),
        }
    }
}

impl std::error::Error for LlmConfigError {}

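/// Builder for [`LlmConfig`].
///
/// `provider`, `base_url`, and `model` are required; `api_key` is optional,
/// and `num_ctx`/`max_tokens` fall back to 4096. A minimal sketch (marked
/// `ignore` since the crate path is not shown here):
///
/// ```ignore
/// let config = LlmConfig::builder()
///     .provider(Provider::OpenAi)
///     .base_url("https://api.openai.com")
///     .model("gpt-4")
///     .api_key("key")
///     .build()?;
/// ```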
#[derive(Default)]
pub struct LlmConfigBuilder {
    provider: Option<Provider>,
    base_url: Option<String>,
    model: Option<String>,
    api_key: Option<String>,
    num_ctx: Option<u32>,
    max_tokens: Option<u32>,
}

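/// Accumulates a system prompt and user messages against a shared
/// [`LlmConfig`], then performs the HTTP call via
/// [`LlmRequestBuilder::send`].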
pub struct LlmRequestBuilder {
    config: Arc<LlmConfig>,
    system: Option<String>,
    messages: Vec<String>,
}

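/// The API style to speak. Marked `#[non_exhaustive]` so further providers
/// can be added without breaking downstream matches.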
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Provider {
    /// Ollama's native chat API.
    Ollama,
    /// OpenAI-compatible chat completions (also spoken by services such as
    /// OpenRouter).
    OpenAi,
    /// Anthropic's Messages API.
    Anthropic,
}

impl Provider {
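    /// Parses a provider name case-insensitively; anything unrecognized
    /// falls back to [`Provider::Ollama`].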
    pub(crate) fn from_str(s: &str) -> Self {
        match s.to_lowercase().as_str() {
            "openai" => Provider::OpenAi,
            "anthropic" => Provider::Anthropic,
            _ => Provider::Ollama,
        }
    }

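    /// Joins the provider-specific chat path onto `base_url`, tolerating a
    /// trailing slash on the base.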
    pub(crate) fn endpoint(&self, base_url: &str) -> String {
        let base = base_url.trim_end_matches('/');
        match self {
            Provider::Ollama => format!("{base}/api/chat"),
            Provider::OpenAi => format!("{base}/v1/chat/completions"),
            Provider::Anthropic => format!("{base}/v1/messages"),
        }
    }

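    /// Extracts the assistant's text from a provider-specific response body,
    /// or returns a [`StepError`] if the expected field is absent.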
    pub(crate) fn parse_response(&self, json: &serde_json::Value) -> Result<String, StepError> {
        let content = match self {
            Provider::Ollama => json["message"]["content"].as_str(),
            Provider::OpenAi => json["choices"][0]["message"]["content"].as_str(),
            Provider::Anthropic => json["content"][0]["text"].as_str(),
        };
        content
            .map(|s| s.to_string())
            .ok_or_else(|| StepError::other("llm response missing message content"))
    }
}

impl LlmConfig {
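    /// Starts a fresh [`LlmConfigBuilder`] with no fields set.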
    pub fn builder() -> LlmConfigBuilder {
        LlmConfigBuilder::default()
    }

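    /// Builds a config from `AGENT_LINE_*` environment variables, falling
    /// back to a local Ollama setup:
    ///
    /// - `AGENT_LINE_PROVIDER` (default `ollama`)
    /// - `AGENT_LINE_LLM_URL` (default `http://localhost:11434`)
    /// - `AGENT_LINE_MODEL` (default `llama3.1:8b`)
    /// - `AGENT_LINE_API_KEY` (optional)
    /// - `AGENT_LINE_NUM_CTX` (default 4096)
    /// - `AGENT_LINE_MAX_TOKENS` (defaults to `num_ctx`)
    ///
    /// When `AGENT_LINE_DEBUG` is set, the resolved config is logged to
    /// stderr.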
    pub fn from_env() -> Self {
        let num_ctx = match env::var("AGENT_LINE_NUM_CTX") {
            Ok(v) => v.parse::<u32>().unwrap_or(4096),
            Err(_) => 4096,
        };
        let max_tokens = match env::var("AGENT_LINE_MAX_TOKENS") {
            Ok(v) => v.parse::<u32>().unwrap_or(num_ctx),
            Err(_) => num_ctx,
        };

        let config = Self {
            provider: Provider::from_str(
                &env::var("AGENT_LINE_PROVIDER").unwrap_or_else(|_| "ollama".to_string()),
            ),
            base_url: env::var("AGENT_LINE_LLM_URL")
                .unwrap_or_else(|_| "http://localhost:11434".to_string()),
            model: env::var("AGENT_LINE_MODEL").unwrap_or_else(|_| "llama3.1:8b".to_string()),
            api_key: env::var("AGENT_LINE_API_KEY").ok(),
            num_ctx,
            max_tokens,
        };
        config.debug_log();
        config
    }

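    /// Replaces the model on this config, consuming and returning it.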
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

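    /// Starts a new request against this config. The config is cloned into
    /// an `Arc`, so this can be called any number of times.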
    pub fn request(&self) -> LlmRequestBuilder {
        LlmRequestBuilder {
            config: Arc::new(self.clone()),
            system: None,
            messages: Vec::new(),
        }
    }

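    /// Logs the resolved configuration to stderr when `AGENT_LINE_DEBUG` is
    /// set. The API key is only reported as "set"/"not set".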
    fn debug_log(&self) {
        if env::var("AGENT_LINE_DEBUG").is_ok() {
            eprintln!(
                "[debug] provider: {:?}\n\
                 [debug] model: {}\n\
                 [debug] base_url: {}\n\
                 [debug] num_ctx: {}\n\
                 [debug] max_tokens: {}\n\
                 [debug] api_key: {}",
                self.provider,
                self.model,
                self.base_url,
                self.num_ctx,
                self.max_tokens,
                if self.api_key.is_some() {
                    "set"
                } else {
                    "not set"
                },
            );
        }
    }
}

impl LlmConfigBuilder {
    pub fn provider(mut self, provider: Provider) -> Self {
        self.provider = Some(provider);
        self
    }

    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

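    /// Sets the context window size. This is only forwarded to Ollama (as
    /// `options.num_ctx` in the request body); defaults to 4096.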
    pub fn num_ctx(mut self, num_ctx: u32) -> Self {
        self.num_ctx = Some(num_ctx);
        self
    }

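    /// Sets the completion limit sent as `max_tokens` to the OpenAI- and
    /// Anthropic-style APIs; defaults to 4096.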
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

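    /// Validates the required fields and produces an [`LlmConfig`], or the
    /// matching [`LlmConfigError`] if `provider`, `base_url`, or `model` is
    /// missing.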
    pub fn build(self) -> Result<LlmConfig, LlmConfigError> {
        Ok(LlmConfig {
            provider: self.provider.ok_or(LlmConfigError::MissingProvider)?,
            base_url: self.base_url.ok_or(LlmConfigError::MissingBaseUrl)?,
            model: self.model.ok_or(LlmConfigError::MissingModel)?,
            api_key: self.api_key,
            num_ctx: self.num_ctx.unwrap_or(4096),
            max_tokens: self.max_tokens.unwrap_or(4096),
        })
    }
}

impl LlmRequestBuilder {
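    /// Sets the system prompt. Only one is kept; calling this again
    /// overwrites the previous value.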
    pub fn system(mut self, msg: impl Into<String>) -> Self {
        self.system = Some(msg.into());
        self
    }

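    /// Appends a user message; may be called repeatedly to queue several
    /// messages in order.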
    pub fn user(mut self, msg: impl Into<String>) -> Self {
        self.messages.push(msg.into());
        self
    }

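    /// Serializes the messages into the provider's request shape, POSTs
    /// them, and returns the assistant's reply text. Network and parse
    /// failures surface as transient [`StepError`]s.
    ///
    /// A minimal sketch of a full round trip (`ignore`d because it needs a
    /// live endpoint and the crate path is not shown here):
    ///
    /// ```ignore
    /// let reply = LlmConfig::from_env()
    ///     .request()
    ///     .system("You are terse.")
    ///     .user("Say hello.")
    ///     .send()?;
    /// println!("{reply}");
    /// ```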
    pub fn send(self) -> Result<String, StepError> {
        let mut messages = Vec::new();

        if let Some(sys) = &self.system {
            messages.push(serde_json::json!({
                "role": "system",
                "content": sys
            }));
        }

        for msg in &self.messages {
            messages.push(serde_json::json!({
                "role": "user",
                "content": msg
            }));
        }

        // Each provider wants a slightly different body shape: Ollama takes
        // `options.num_ctx`, while OpenAI and Anthropic take `max_tokens`.
        let body = match &self.config.provider {
            Provider::Ollama => serde_json::json!({
                "model": self.config.model,
                "messages": messages,
                "stream": false,
                // `think: false` disables Ollama's separate thinking output
                // on reasoning models.
                "think": false,
                "options": {
                    "num_ctx": self.config.num_ctx
                }
            }),
            Provider::OpenAi => serde_json::json!({
                "model": self.config.model,
                "messages": messages,
                "stream": false,
                "max_tokens": self.config.max_tokens
            }),
            Provider::Anthropic => serde_json::json!({
                "model": self.config.model,
                "messages": messages,
                "stream": false,
                "max_tokens": self.config.max_tokens
            }),
        };

        let url = self.config.provider.endpoint(&self.config.base_url);
        let mut request = ureq::post(&url);

        // Anthropic authenticates with `x-api-key` plus a version header;
        // the other providers use a bearer token.
        match &self.config.provider {
            Provider::Anthropic => {
                if let Some(key) = &self.config.api_key {
                    request = request.header("x-api-key", key);
                }
                request = request.header("anthropic-version", "2023-06-01");
                request = request.header("content-type", "application/json");
            }
            _ => {
                if let Some(key) = &self.config.api_key {
                    request = request.header("Authorization", &format!("Bearer {key}"));
                }
            }
        }

        if env::var("AGENT_LINE_DEBUG").is_ok() {
            eprintln!("[debug] LLM request to {url}");
            eprintln!(
                "[debug] Messages: {}",
                serde_json::to_string_pretty(&messages).unwrap_or_default()
            );
        }

        let mut response = request
            .send_json(&body)
            .map_err(|e| StepError::transient(format!("llm request failed: {e}")))?;

        let json: serde_json::Value = response
            .body_mut()
            .read_json()
            .map_err(|e| StepError::transient(format!("llm response parse failed: {e}")))?;

        if env::var("AGENT_LINE_DEBUG").is_ok() {
            eprintln!("[debug] LLM response: {json}");
        }

        self.config.provider.parse_response(&json)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_from_str_ollama() {
        assert_eq!(Provider::from_str("ollama"), Provider::Ollama);
    }

    #[test]
    fn test_provider_from_str_openai() {
        assert_eq!(Provider::from_str("openai"), Provider::OpenAi);
    }

    #[test]
    fn test_provider_from_str_anthropic() {
        assert_eq!(Provider::from_str("anthropic"), Provider::Anthropic);
    }

    #[test]
    fn test_provider_from_str_case_insensitive() {
        assert_eq!(Provider::from_str("OpenAI"), Provider::OpenAi);
        assert_eq!(Provider::from_str("ANTHROPIC"), Provider::Anthropic);
        assert_eq!(Provider::from_str("Ollama"), Provider::Ollama);
    }

    #[test]
    fn test_provider_from_str_unknown_defaults_to_ollama() {
        assert_eq!(Provider::from_str("something"), Provider::Ollama);
    }

    #[test]
    fn test_ollama_endpoint() {
        assert_eq!(
            Provider::Ollama.endpoint("http://localhost:11434"),
            "http://localhost:11434/api/chat"
        );
    }

    #[test]
    fn test_openai_endpoint() {
        assert_eq!(
            Provider::OpenAi.endpoint("https://openrouter.ai"),
            "https://openrouter.ai/v1/chat/completions"
        );
    }

    #[test]
    fn test_anthropic_endpoint() {
        assert_eq!(
            Provider::Anthropic.endpoint("https://api.anthropic.com"),
            "https://api.anthropic.com/v1/messages"
        );
    }

    #[test]
    fn test_endpoint_strips_trailing_slash() {
        assert_eq!(
            Provider::OpenAi.endpoint("https://openrouter.ai/"),
            "https://openrouter.ai/v1/chat/completions"
        );
    }

    #[test]
    fn test_ollama_parse_response() {
        let json = serde_json::json!({
            "message": { "content": "Hello from Ollama" }
        });
        assert_eq!(
            Provider::Ollama.parse_response(&json).unwrap(),
            "Hello from Ollama"
        );
    }

    #[test]
    fn test_openai_parse_response() {
        let json = serde_json::json!({
            "choices": [{ "message": { "content": "Hello from OpenRouter" } }]
        });
        assert_eq!(
            Provider::OpenAi.parse_response(&json).unwrap(),
            "Hello from OpenRouter"
        );
    }

    #[test]
    fn test_anthropic_parse_response() {
        let json = serde_json::json!({
            "content": [{ "text": "Hello from Claude" }]
        });
        assert_eq!(
            Provider::Anthropic.parse_response(&json).unwrap(),
            "Hello from Claude"
        );
    }

    #[test]
    fn test_parse_response_missing_content_is_error() {
        let json = serde_json::json!({"unexpected": "shape"});
        assert!(Provider::Ollama.parse_response(&json).is_err());
        assert!(Provider::OpenAi.parse_response(&json).is_err());
        assert!(Provider::Anthropic.parse_response(&json).is_err());
    }

    #[test]
    fn llm_config_builder_happy_path() {
        let config = LlmConfig::builder()
            .provider(Provider::OpenAi)
            .base_url("https://example.com")
            .model("gpt-4")
            .api_key("key")
            .num_ctx(8192)
            .max_tokens(2048)
            .build()
            .unwrap();

        assert_eq!(config.provider, Provider::OpenAi);
        assert_eq!(config.base_url, "https://example.com");
        assert_eq!(config.model, "gpt-4");
        assert_eq!(config.api_key.as_deref(), Some("key"));
        assert_eq!(config.num_ctx, 8192);
        assert_eq!(config.max_tokens, 2048);
    }

    #[test]
    fn llm_config_builder_defaults_token_fields_to_4096() {
        let config = LlmConfig::builder()
            .provider(Provider::Ollama)
            .base_url("http://localhost:11434")
            .model("llama3")
            .build()
            .unwrap();

        assert_eq!(config.num_ctx, 4096);
        assert_eq!(config.max_tokens, 4096);
    }

    #[test]
    fn llm_config_builder_api_key_optional() {
        let config = LlmConfig::builder()
            .provider(Provider::Ollama)
            .base_url("http://localhost:11434")
            .model("llama3")
            .build()
            .unwrap();

        assert!(config.api_key.is_none());
    }

    #[test]
    fn llm_config_builder_errors_without_provider() {
        let err = LlmConfig::builder()
            .base_url("http://localhost:11434")
            .model("llama3")
            .build()
            .unwrap_err();

        assert_eq!(err, LlmConfigError::MissingProvider);
    }

    #[test]
    fn llm_config_builder_errors_without_base_url() {
        let err = LlmConfig::builder()
            .provider(Provider::Ollama)
            .model("llama3")
            .build()
            .unwrap_err();

        assert_eq!(err, LlmConfigError::MissingBaseUrl);
    }

    #[test]
    fn llm_config_builder_errors_without_model() {
        let err = LlmConfig::builder()
            .provider(Provider::Ollama)
            .base_url("http://localhost:11434")
            .build()
            .unwrap_err();

        assert_eq!(err, LlmConfigError::MissingModel);
    }

    #[test]
    fn request_uses_owned_config() {
        let cfg = LlmConfig::builder()
            .provider(Provider::Ollama)
            .base_url("http://localhost:11434")
            .model("llama3")
            .build()
            .unwrap();

        let req = cfg.request().system("hi").user("hello");

        assert_eq!(req.config.model, "llama3");
        assert_eq!(req.config.provider, Provider::Ollama);
        assert_eq!(req.config.base_url, "http://localhost:11434");
    }

    #[test]
    fn request_can_be_called_repeatedly_on_same_config() {
        let cfg = LlmConfig::builder()
            .provider(Provider::Ollama)
            .base_url("http://localhost:11434")
            .model("llama3")
            .build()
            .unwrap();

        let r1 = cfg.request().user("first");
        let r2 = cfg.request().user("second");

        assert_eq!(r1.messages, vec!["first".to_string()]);
        assert_eq!(r2.messages, vec!["second".to_string()]);
        assert_eq!(r1.config.model, "llama3");
        assert_eq!(r2.config.model, "llama3");
    }
}