use serde::{Deserialize, Serialize};
use stakai::Model;
use std::collections::HashMap;

/// Connection settings for a single LLM provider backend.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    OpenAI {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    Anthropic {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// OAuth access token, usable in place of an API key.
        #[serde(skip_serializing_if = "Option::is_none")]
        access_token: Option<String>,
    },
    Gemini {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    /// Any OpenAI-compatible endpoint (e.g. LiteLLM, Ollama); the endpoint is required.
    Custom {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        api_endpoint: String,
    },
    Stakpak {
        api_key: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
}

impl ProviderConfig {
    /// The canonical lowercase name of this provider, matching the serde tag.
    pub fn provider_type(&self) -> &'static str {
        match self {
            ProviderConfig::OpenAI { .. } => "openai",
            ProviderConfig::Anthropic { .. } => "anthropic",
            ProviderConfig::Gemini { .. } => "gemini",
            ProviderConfig::Custom { .. } => "custom",
            ProviderConfig::Stakpak { .. } => "stakpak",
        }
    }

    /// The configured API key, if any.
    pub fn api_key(&self) -> Option<&str> {
        match self {
            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Stakpak { api_key, .. } => Some(api_key.as_str()),
        }
    }

    /// The configured API endpoint, if any (always present for `Custom`).
    pub fn api_endpoint(&self) -> Option<&str> {
        match self {
            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
            ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
        }
    }

    /// The OAuth access token; only `Anthropic` supports one.
    pub fn access_token(&self) -> Option<&str> {
        match self {
            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
            _ => None,
        }
    }

    /// Convenience constructor for an OpenAI provider with the default endpoint.
    pub fn openai(api_key: Option<String>) -> Self {
        ProviderConfig::OpenAI {
            api_key,
            api_endpoint: None,
        }
    }

    /// Convenience constructor for an Anthropic provider with the default endpoint.
    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
        ProviderConfig::Anthropic {
            api_key,
            api_endpoint: None,
            access_token,
        }
    }

    /// Convenience constructor for a Gemini provider with the default endpoint.
    pub fn gemini(api_key: Option<String>) -> Self {
        ProviderConfig::Gemini {
            api_key,
            api_endpoint: None,
        }
    }

    /// Convenience constructor for a custom OpenAI-compatible provider.
    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
        ProviderConfig::Custom {
            api_key,
            api_endpoint,
        }
    }

    /// Convenience constructor for a Stakpak provider.
    pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
        ProviderConfig::Stakpak {
            api_key,
            api_endpoint,
        }
    }
}

/// A collection of provider configurations, keyed by provider name.
#[derive(Debug, Clone, Default)]
pub struct LLMProviderConfig {
    pub providers: HashMap<String, ProviderConfig>,
}

impl LLMProviderConfig {
    pub fn new() -> Self {
        Self {
            providers: HashMap::new(),
        }
    }

    /// Registers (or replaces) a provider under the given name.
    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
        self.providers.insert(name.into(), config);
    }

    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
        self.providers.get(name)
    }

    pub fn is_empty(&self) -> bool {
        self.providers.is_empty()
    }
}

/// Provider-specific generation options, grouped per vendor.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}

/// Extended-thinking configuration for Anthropic models.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    pub budget_tokens: u32,
}

impl LLMThinkingOptions {
    /// Creates thinking options, clamping the budget to a 1024-token minimum.
    pub fn new(budget_tokens: u32) -> Self {
        Self {
            budget_tokens: budget_tokens.max(1024),
        }
    }
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    /// Reasoning effort hint (e.g. "low", "medium", "high").
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}

/// A complete, non-streaming generation request.
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    pub model: Model,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub tools: Option<Vec<LLMTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<LLMProviderOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
}

/// A streaming generation request; deltas are sent over `stream_channel_tx`.
#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: Model,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
    pub provider_options: Option<LLMProviderOptions>,
    pub headers: Option<HashMap<String, String>>,
}

impl From<&LLMStreamInput> for LLMInput {
    fn from(value: &LLMStreamInput) -> Self {
        LLMInput {
            model: value.model.clone(),
            messages: value.messages.clone(),
            max_tokens: value.max_tokens,
            tools: value.tools.clone(),
            provider_options: value.provider_options.clone(),
            headers: value.headers.clone(),
        }
    }
}

/// A chat message with a free-form role and rich content.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    pub role: String,
    pub content: LLMMessageContent,
}

/// A plain-text message restricted to the user/assistant roles.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SimpleLLMMessage {
    pub role: SimpleLLMRole,
    pub content: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}

impl std::fmt::Display for SimpleLLMRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SimpleLLMRole::User => write!(f, "user"),
            SimpleLLMRole::Assistant => write!(f, "assistant"),
        }
    }
}

/// Message content: either a plain string or a list of typed parts.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}

#[allow(clippy::to_string_trait_impl)]
impl ToString for LLMMessageContent {
    fn to_string(&self) -> String {
        match self {
            LLMMessageContent::String(s) => s.clone(),
            LLMMessageContent::List(l) => l
                .iter()
                .map(|c| match c {
                    LLMMessageTypedContent::Text { text } => text.clone(),
                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
                    LLMMessageTypedContent::Image { .. } => String::new(),
                })
                .collect::<Vec<_>>()
                .join("\n"),
        }
    }
}

impl From<String> for LLMMessageContent {
    fn from(value: String) -> Self {
        LLMMessageContent::String(value)
    }
}

impl Default for LLMMessageContent {
    fn default() -> Self {
        LLMMessageContent::String(String::new())
    }
}

impl LLMMessageContent {
    /// Normalizes the content into typed parts; an empty string yields no parts.
    pub fn into_parts(self) -> Vec<LLMMessageTypedContent> {
        match self {
            LLMMessageContent::List(parts) => parts,
            LLMMessageContent::String(s) if s.is_empty() => vec![],
            LLMMessageContent::String(s) => vec![LLMMessageTypedContent::Text { text: s }],
        }
    }
}

/// A single typed content part, tagged by `type` on the wire.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        #[serde(alias = "input")]
        args: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        metadata: Option<serde_json::Value>,
    },
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}

/// An image payload: its encoding type, media type, and the data itself.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    #[serde(rename = "type")]
    pub r#type: String,
    pub media_type: String,
    pub data: String,
}

impl Default for LLMMessageTypedContent {
    fn default() -> Self {
        LLMMessageTypedContent::Text {
            text: String::new(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: LLMMessage,
}

/// A complete (non-streaming) completion response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: Option<LLMMessage>,
    pub delta: LLMStreamDelta,
}

/// A single chunk of a streaming completion response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    pub citations: Option<Vec<String>>,
}

/// A tool definition exposed to the model; `input_schema` is a JSON Schema.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

/// Token accounting for a single request.
#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}

/// Fine-grained prompt-token breakdown, including cache reads and writes.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}

impl PromptTokensDetails {
    /// Iterates over all counters, treating missing values as zero.
    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
        [
            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
            (
                TokenType::CacheReadInputTokens,
                self.cache_read_input_tokens.unwrap_or(0),
            ),
            (
                TokenType::CacheWriteInputTokens,
                self.cache_write_input_tokens.unwrap_or(0),
            ),
        ]
        .into_iter()
    }
}

// `None` counters are treated as zero when summing, so adding partial
// details never loses a field.
impl std::ops::Add for PromptTokensDetails {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        Self {
            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
            cache_read_input_tokens: Some(
                self.cache_read_input_tokens.unwrap_or(0)
                    + rhs.cache_read_input_tokens.unwrap_or(0),
            ),
            cache_write_input_tokens: Some(
                self.cache_write_input_tokens.unwrap_or(0)
                    + rhs.cache_write_input_tokens.unwrap_or(0),
            ),
        }
    }
}

impl std::ops::AddAssign for PromptTokensDetails {
    fn add_assign(&mut self, rhs: Self) {
        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
        self.cache_read_input_tokens = Some(
            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
        );
        self.cache_write_input_tokens = Some(
            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
        );
    }
}

/// An incremental update emitted while streaming a generation.
///
/// Note: the enum is internally tagged but has no `rename_all`, so the tag
/// carries the variant name verbatim (e.g. `"Content"`).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    Content { content: String },
    Thinking { thinking: String },
    ToolUse { tool_use: GenerationDeltaToolUse },
    Usage { usage: LLMTokenUsage },
    Metadata { metadata: serde_json::Value },
}

/// A partial tool-use update within a stream, indexed by tool-call position.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    pub id: Option<String>,
    pub name: Option<String>,
    pub input: Option<String>,
    pub index: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_config_openai_serialization() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"openai\""));
        assert!(json.contains("\"api_key\":\"sk-test\""));
        assert!(!json.contains("api_endpoint"));
    }

    #[test]
    fn test_provider_config_openai_with_endpoint() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
    }

    #[test]
    fn test_provider_config_anthropic_serialization() {
        let config = ProviderConfig::Anthropic {
            api_key: Some("sk-ant-test".to_string()),
            api_endpoint: None,
            access_token: Some("oauth-token".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"anthropic\""));
        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
        assert!(json.contains("\"access_token\":\"oauth-token\""));
    }

    #[test]
    fn test_provider_config_gemini_serialization() {
        let config = ProviderConfig::Gemini {
            api_key: Some("gemini-key".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"gemini\""));
        assert!(json.contains("\"api_key\":\"gemini-key\""));
    }

    #[test]
    fn test_provider_config_custom_serialization() {
        let config = ProviderConfig::Custom {
            api_key: Some("sk-custom".to_string()),
            api_endpoint: "http://localhost:4000".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
        assert!(json.contains("\"api_key\":\"sk-custom\""));
    }

    #[test]
    fn test_provider_config_custom_without_key() {
        let config = ProviderConfig::Custom {
            api_key: None,
            api_endpoint: "http://localhost:11434/v1".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\""));
        assert!(!json.contains("api_key"));
    }
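
    // Serialization check for the remaining variant, mirroring the tests
    // above: the required api_key is always emitted, while a `None`
    // endpoint is skipped. The key value is a placeholder.
    #[test]
    fn test_provider_config_stakpak_serialization() {
        let config = ProviderConfig::Stakpak {
            api_key: "sk-stakpak".to_string(),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"stakpak\""));
        assert!(json.contains("\"api_key\":\"sk-stakpak\""));
        assert!(!json.contains("api_endpoint"));
    }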

    #[test]
    fn test_provider_config_deserialization_openai() {
        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
        assert_eq!(config.api_key(), Some("sk-test"));
    }

    #[test]
    fn test_provider_config_deserialization_anthropic() {
        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
        assert_eq!(config.api_key(), Some("sk-ant"));
        assert_eq!(config.access_token(), Some("oauth"));
    }

    #[test]
    fn test_provider_config_deserialization_gemini() {
        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Gemini { .. }));
        assert_eq!(config.api_key(), Some("gemini-key"));
    }

    #[test]
    fn test_provider_config_deserialization_custom() {
        let json =
            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Custom { .. }));
        assert_eq!(config.api_key(), Some("sk-custom"));
        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
    }
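
    // Mirrors the deserialization checks above for the one variant they
    // skip; the key value is a placeholder, not a real credential.
    #[test]
    fn test_provider_config_deserialization_stakpak() {
        let json = r#"{"type":"stakpak","api_key":"sk-stakpak"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Stakpak { .. }));
        assert_eq!(config.api_key(), Some("sk-stakpak"));
        assert_eq!(config.api_endpoint(), None);
    }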

    #[test]
    fn test_provider_config_helper_methods() {
        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        assert_eq!(openai.provider_type(), "openai");
        assert_eq!(openai.api_key(), Some("sk-openai"));

        let anthropic =
            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
        assert_eq!(anthropic.provider_type(), "anthropic");
        assert_eq!(anthropic.access_token(), Some("oauth"));

        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
        assert_eq!(gemini.provider_type(), "gemini");

        let custom = ProviderConfig::custom(
            "http://localhost:4000".to_string(),
            Some("sk-custom".to_string()),
        );
        assert_eq!(custom.provider_type(), "custom");
        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
    }
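
    // Covers the remaining constructor; the endpoint URL here is an
    // illustrative placeholder, not a documented default.
    #[test]
    fn test_provider_config_stakpak_helper() {
        let stakpak = ProviderConfig::stakpak(
            "sk-stakpak".to_string(),
            Some("https://api.stakpak.dev".to_string()),
        );
        assert_eq!(stakpak.provider_type(), "stakpak");
        assert_eq!(stakpak.api_key(), Some("sk-stakpak"));
        assert_eq!(stakpak.api_endpoint(), Some("https://api.stakpak.dev"));
        assert_eq!(stakpak.access_token(), None);
    }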

    #[test]
    fn test_llm_provider_config_new() {
        let config = LLMProviderConfig::new();
        assert!(config.is_empty());
    }

    #[test]
    fn test_llm_provider_config_add_and_get() {
        let mut config = LLMProviderConfig::new();
        config.add_provider(
            "openai",
            ProviderConfig::openai(Some("sk-test".to_string())),
        );
        config.add_provider(
            "anthropic",
            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
        );

        assert!(!config.is_empty());
        assert!(config.get_provider("openai").is_some());
        assert!(config.get_provider("anthropic").is_some());
        assert!(config.get_provider("unknown").is_none());
    }
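
    // Illustrates the untagged LLMMessageContent wire format: a bare JSON
    // string maps to the String variant, an array of typed parts to List,
    // and into_parts normalizes both (an empty string yields no parts).
    #[test]
    fn test_llm_message_content_untagged() {
        let s: LLMMessageContent = serde_json::from_str(r#""hello""#).unwrap();
        assert!(matches!(s, LLMMessageContent::String(_)));
        let parts = s.into_parts();
        assert_eq!(parts.len(), 1);
        assert!(matches!(&parts[0], LLMMessageTypedContent::Text { text } if text == "hello"));

        let l: LLMMessageContent =
            serde_json::from_str(r#"[{"type":"text","text":"hi"}]"#).unwrap();
        assert_eq!(l.to_string(), "hi");

        assert!(LLMMessageContent::String(String::new())
            .into_parts()
            .is_empty());
    }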

    // Parses a provider map in the same shape the TOML config deserializes
    // into, exercised here via JSON.
    #[test]
    fn test_provider_config_map_parsing() {
        let json = r#"{
            "openai": {"type": "openai", "api_key": "sk-openai"},
            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
        }"#;

        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 3);

        assert!(matches!(
            providers.get("openai"),
            Some(ProviderConfig::OpenAI { .. })
        ));
        assert!(matches!(
            providers.get("anthropic"),
            Some(ProviderConfig::Anthropic { .. })
        ));
        assert!(matches!(
            providers.get("litellm"),
            Some(ProviderConfig::Custom { .. })
        ));
    }
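
    // Arithmetic sketch: Add and AddAssign treat missing counters as zero,
    // so summing partial details never drops a field.
    #[test]
    fn test_prompt_tokens_details_add() {
        let a = PromptTokensDetails {
            input_tokens: Some(10),
            cache_read_input_tokens: Some(5),
            ..Default::default()
        };
        let b = PromptTokensDetails {
            input_tokens: Some(3),
            output_tokens: Some(7),
            ..Default::default()
        };

        let sum = a.clone() + b.clone();
        assert_eq!(sum.input_tokens, Some(13));
        assert_eq!(sum.output_tokens, Some(7));
        assert_eq!(sum.cache_read_input_tokens, Some(5));
        assert_eq!(sum.cache_write_input_tokens, Some(0));

        let mut acc = a;
        acc += b;
        assert_eq!(acc.input_tokens, Some(13));
    }

    // The constructor clamps the thinking budget to its 1024-token floor.
    #[test]
    fn test_thinking_options_minimum_budget() {
        assert_eq!(LLMThinkingOptions::new(1).budget_tokens, 1024);
        assert_eq!(LLMThinkingOptions::new(4096).budget_tokens, 4096);
    }

    // GenerationDelta is internally tagged without rename_all, so the tag
    // carries the variant name verbatim ("Content", not "content").
    #[test]
    fn test_generation_delta_tag_uses_variant_name() {
        let delta = GenerationDelta::Content {
            content: "hi".to_string(),
        };
        let json = serde_json::to_string(&delta).unwrap();
        assert!(json.contains("\"type\":\"Content\""));
    }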
}