//! `openai_agents_rust/model/litellm.rs`: [`LiteLLM`], a [`Model`] backed by any OpenAI-compatible chat-completions endpoint.

1use crate::config::Config;
2use crate::error::AgentError;
3use crate::model::{Model, ModelResponse, ToolCall};
4use async_trait::async_trait;
5use reqwest::Client;
6use serde::Deserialize;
7
8/// Simple wrapper around any LLM that follows the OpenAI‑compatible API.
9pub struct LiteLLM {
10    client: Client,
11    config: Config,
12    base_url: String,
13    auth_token: Option<String>,
14}
15
16impl LiteLLM {
17    /// Create a new instance with the given configuration.
18    pub fn new(config: Config) -> Self {
19        let client = Client::builder()
20            .user_agent("openai-agents-rust")
21            .build()
22            .expect("Failed to build reqwest client");
23        let auth_token = if config.api_key.is_empty() {
24            None
25        } else {
26            Some(config.api_key.clone())
27        };
28    let base_url = config.base_url.clone();
29    Self { client, config, base_url, auth_token }
30    }
31
32    /// Override the base URL (e.g., http://192.168.3.40:8000/v1)
33    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
34        self.base_url = base_url.into();
35        self
36    }
37
38    /// Disable authentication for local/open endpoints.
39    pub fn without_auth(mut self) -> Self {
40        self.auth_token = None;
41        self
42    }
43
44    /// Set a custom auth token (Bearer). Use None to disable.
45    pub fn with_auth_token(mut self, token: Option<impl Into<String>>) -> Self {
46        self.auth_token = token.map(|t| t.into());
47        self
48    }
49}
50
51#[async_trait]
52impl Model for LiteLLM {
53    /// Sends a chat completion request to the configured endpoint.
54    async fn generate(&self, prompt: &str) -> Result<String, AgentError> {
55        // For demonstration we reuse the OpenAI chat endpoint.
56        let url = format!("{}/chat/completions", self.base_url);
57        let mut req = self.client.post(&url);
58        if let Some(token) = &self.auth_token {
59            req = req.bearer_auth(token);
60        }
61        let resp = req
62            .json(&serde_json::json!({
63                "model": self.config.model,
64                "messages": [{ "role": "user", "content": prompt }],
65                "max_tokens": 512,
66            }))
67            .send()
68            .await
69            .map_err(AgentError::from)?;
70
71        let text = resp.text().await.map_err(AgentError::from)?;
72        Ok(text)
73    }
74
75    async fn get_response(
76        &self,
77        system_instructions: Option<&str>,
78        input: &str,
79        _model_settings: Option<std::collections::HashMap<String, String>>,
80        messages: Option<&[serde_json::Value]>,
81        tools: Option<&[serde_json::Value]>,
82        tool_choice: Option<serde_json::Value>,
83        _output_schema: Option<&str>,
84        _handoffs: Option<&[String]>,
85        _tracing_enabled: bool,
86        _previous_response_id: Option<&str>,
87        _prompt_config: Option<&str>,
88    ) -> Result<ModelResponse, AgentError> {
89        let url = format!("{}/chat/completions", self.base_url);
90        let mut msgs: Vec<serde_json::Value> = Vec::new();
91        if let Some(provided) = messages {
92            msgs.extend_from_slice(provided);
93        } else {
94            if let Some(sys) = system_instructions {
95                msgs.push(serde_json::json!({"role": "system", "content": sys}));
96            }
97            msgs.push(serde_json::json!({"role": "user", "content": input}));
98        }
99
100        let mut req = self.client.post(&url);
101        if let Some(token) = &self.auth_token {
102            req = req.bearer_auth(token);
103        }
104        let resp = req
105            .json(&{
106                let mut payload = serde_json::json!({
107                    "model": self.config.model,
108                    "messages": msgs,
109                    "max_tokens": 512,
110                    "temperature": 0.2,
111                });
112                if let Some(t) = tools {
113                    payload["tools"] = serde_json::Value::Array(t.to_vec());
114                }
115                if let Some(choice) = tool_choice {
116                    payload["tool_choice"] = choice;
117                }
118                payload
119            })
120            .send()
121            .await
122            .map_err(AgentError::from)?;
123
124        #[derive(Deserialize)]
125        struct FunctionCall {
126            name: String,
127            arguments: serde_json::Value,
128        }
129        #[derive(Deserialize)]
130        struct ToolCallJson {
131            #[serde(rename = "type")]
132            _type: Option<String>,
133            id: Option<String>,
134            call_id: Option<String>,
135            function: Option<FunctionCall>,
136        }
137        #[derive(Deserialize)]
138        struct Message {
139            content: Option<String>,
140            tool_calls: Option<Vec<ToolCallJson>>,
141            function_call: Option<FunctionCall>,
142        }
143        #[derive(Deserialize)]
144        struct Choice {
145            message: Message,
146        }
147        #[derive(Deserialize)]
148        struct ChatCompletion {
149            choices: Vec<Choice>,
150        }
151
152        let status = resp.status();
153        let body_text = resp.text().await.map_err(AgentError::from)?;
154        if !status.is_success() {
155            return Err(AgentError::Other(format!(
156                "HTTP {} error: {}",
157                status, body_text
158            )));
159        }
160        match serde_json::from_str::<ChatCompletion>(&body_text) {
161            Ok(body) => {
162                let mut text: Option<String> = None;
163                let mut tool_calls: Vec<ToolCall> = Vec::new();
164                if let Some(first) = body.choices.into_iter().next() {
165                    text = first.message.content;
166                    if let Some(tcs) = first.message.tool_calls {
167                        for tc in tcs.into_iter() {
168                            if let Some(func) = tc.function {
169                                tool_calls.push(ToolCall {
170                                    id: tc.id,
171                                    name: func.name,
172                                    arguments: match func.arguments {
173                                        serde_json::Value::String(s) => s,
174                                        other => other.to_string(),
175                                    },
176                                    call_id: tc.call_id,
177                                });
178                            }
179                        }
180                    } else if let Some(func) = first.message.function_call {
181                        tool_calls.push(ToolCall {
182                            id: None,
183                            name: func.name,
184                            arguments: match func.arguments {
185                                serde_json::Value::String(s) => s,
186                                other => other.to_string(),
187                            },
188                            call_id: None,
189                        });
190                    }
191                }
192                Ok(ModelResponse {
193                    id: None,
194                    text,
195                    tool_calls,
196                })
197            }
198            Err(_) => {
199                if let Ok(v) = serde_json::from_str::<serde_json::Value>(&body_text) {
200                    let text = v
201                        .get("choices")
202                        .and_then(|c| c.as_array())
203                        .and_then(|arr| arr.get(0))
204                        .and_then(|c0| {
205                            c0.get("message")
206                                .and_then(|m| m.get("content"))
207                                .and_then(|t| t.as_str())
208                                .map(|s| s.to_string())
209                                .or_else(|| {
210                                    c0.get("text")
211                                        .and_then(|t| t.as_str())
212                                        .map(|s| s.to_string())
213                                })
214                        });
215                    return Ok(ModelResponse {
216                        id: None,
217                        text,
218                        tool_calls: vec![],
219                    });
220                }
221                Ok(ModelResponse {
222                    id: None,
223                    text: Some(body_text),
224                    tool_calls: vec![],
225                })
226            }
227        }
228    }
229}