Skip to main content

cognee_llm/
config.rs

//! LLM configuration: provider selection, model parameters, and request timeout/retry/rate-limit settings.
2
3use serde::{Deserialize, Serialize};
4
5/// LLM provider type.
6#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum LlmProvider {
9    OpenAI,
10    LiteRt,
11    Anthropic,
12    Ollama,
13    Gemini,
14    Mistral,
15    Bedrock,
16    Local,
17    /// Record/replay cassette-based mock LLM (selected via `MOCK_LLM`).
18    Mock,
19}
20
21/// LLM configuration.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct LlmConfig {
24    /// LLM provider.
25    pub provider: LlmProvider,
26
27    /// Model identifier (e.g., "gpt-4", "claude-3-opus").
28    pub model: String,
29
30    /// API key (if required by provider).
31    pub api_key: Option<String>,
32
33    /// API endpoint (custom endpoint for self-hosted models).
34    pub endpoint: Option<String>,
35
36    /// Default temperature.
37    pub temperature: f32,
38
39    /// Default max tokens.
40    pub max_tokens: u32,
41
42    /// Enable streaming responses.
43    pub streaming: bool,
44
45    /// Request timeout in seconds.
46    pub timeout_seconds: u64,
47
48    /// Maximum number of retries.
49    pub max_retries: u32,
50
51    /// Enable rate limiting.
52    pub rate_limit_enabled: bool,
53
54    /// Rate limit: requests per interval.
55    pub rate_limit_requests: u32,
56
57    /// Rate limit interval in seconds.
58    pub rate_limit_interval_seconds: u64,
59
60    /// Fallback model (if primary fails).
61    pub fallback_model: Option<String>,
62
63    /// Fallback API key.
64    pub fallback_api_key: Option<String>,
65}
66
67impl Default for LlmConfig {
68    fn default() -> Self {
69        Self {
70            provider: LlmProvider::OpenAI,
71            model: "gpt-4".to_string(),
72            api_key: None,
73            endpoint: None,
74            temperature: 0.0,
75            max_tokens: 16384,
76            streaming: false,
77            timeout_seconds: 120,
78            max_retries: 3,
79            rate_limit_enabled: false,
80            rate_limit_requests: 60,
81            rate_limit_interval_seconds: 60,
82            fallback_model: None,
83            fallback_api_key: None,
84        }
85    }
86}
87
88impl LlmConfig {
89    /// Create a new configuration with minimal required fields.
90    pub fn new(provider: LlmProvider, model: impl Into<String>) -> Self {
91        Self {
92            provider,
93            model: model.into(),
94            ..Default::default()
95        }
96    }
97
98    /// Set API key.
99    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
100        self.api_key = Some(api_key.into());
101        self
102    }
103
104    /// Set custom endpoint.
105    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
106        self.endpoint = Some(endpoint.into());
107        self
108    }
109
110    /// Set temperature.
111    pub fn with_temperature(mut self, temperature: f32) -> Self {
112        self.temperature = temperature;
113        self
114    }
115
116    /// Set max tokens.
117    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
118        self.max_tokens = max_tokens;
119        self
120    }
121
122    /// Set timeout.
123    pub fn with_timeout_seconds(mut self, timeout_seconds: u64) -> Self {
124        self.timeout_seconds = timeout_seconds;
125        self
126    }
127
128    /// Set max retries.
129    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
130        self.max_retries = max_retries;
131        self
132    }
133}