openai_agents_rust/model/
litellm.rs1use crate::config::Config;
2use crate::error::AgentError;
3use crate::model::{Model, ModelResponse, ToolCall};
4use async_trait::async_trait;
5use reqwest::Client;
6use serde::Deserialize;
7
8pub struct LiteLLM {
10 client: Client,
11 config: Config,
12 base_url: String,
13 auth_token: Option<String>,
14}
15
16impl LiteLLM {
17 pub fn new(config: Config) -> Self {
19 let client = Client::builder()
20 .user_agent("openai-agents-rust")
21 .build()
22 .expect("Failed to build reqwest client");
23 let auth_token = if config.api_key.is_empty() {
24 None
25 } else {
26 Some(config.api_key.clone())
27 };
28 let base_url = config.base_url.clone();
29 Self { client, config, base_url, auth_token }
30 }
31
32 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
34 self.base_url = base_url.into();
35 self
36 }
37
38 pub fn without_auth(mut self) -> Self {
40 self.auth_token = None;
41 self
42 }
43
44 pub fn with_auth_token(mut self, token: Option<impl Into<String>>) -> Self {
46 self.auth_token = token.map(|t| t.into());
47 self
48 }
49}
50
51#[async_trait]
52impl Model for LiteLLM {
53 async fn generate(&self, prompt: &str) -> Result<String, AgentError> {
55 let url = format!("{}/chat/completions", self.base_url);
57 let mut req = self.client.post(&url);
58 if let Some(token) = &self.auth_token {
59 req = req.bearer_auth(token);
60 }
61 let resp = req
62 .json(&serde_json::json!({
63 "model": self.config.model,
64 "messages": [{ "role": "user", "content": prompt }],
65 "max_tokens": 512,
66 }))
67 .send()
68 .await
69 .map_err(AgentError::from)?;
70
71 let text = resp.text().await.map_err(AgentError::from)?;
72 Ok(text)
73 }
74
75 async fn get_response(
76 &self,
77 system_instructions: Option<&str>,
78 input: &str,
79 _model_settings: Option<std::collections::HashMap<String, String>>,
80 messages: Option<&[serde_json::Value]>,
81 tools: Option<&[serde_json::Value]>,
82 tool_choice: Option<serde_json::Value>,
83 _output_schema: Option<&str>,
84 _handoffs: Option<&[String]>,
85 _tracing_enabled: bool,
86 _previous_response_id: Option<&str>,
87 _prompt_config: Option<&str>,
88 ) -> Result<ModelResponse, AgentError> {
89 let url = format!("{}/chat/completions", self.base_url);
90 let mut msgs: Vec<serde_json::Value> = Vec::new();
91 if let Some(provided) = messages {
92 msgs.extend_from_slice(provided);
93 } else {
94 if let Some(sys) = system_instructions {
95 msgs.push(serde_json::json!({"role": "system", "content": sys}));
96 }
97 msgs.push(serde_json::json!({"role": "user", "content": input}));
98 }
99
100 let mut req = self.client.post(&url);
101 if let Some(token) = &self.auth_token {
102 req = req.bearer_auth(token);
103 }
104 let resp = req
105 .json(&{
106 let mut payload = serde_json::json!({
107 "model": self.config.model,
108 "messages": msgs,
109 "max_tokens": 512,
110 "temperature": 0.2,
111 });
112 if let Some(t) = tools {
113 payload["tools"] = serde_json::Value::Array(t.to_vec());
114 }
115 if let Some(choice) = tool_choice {
116 payload["tool_choice"] = choice;
117 }
118 payload
119 })
120 .send()
121 .await
122 .map_err(AgentError::from)?;
123
124 #[derive(Deserialize)]
125 struct FunctionCall {
126 name: String,
127 arguments: serde_json::Value,
128 }
129 #[derive(Deserialize)]
130 struct ToolCallJson {
131 #[serde(rename = "type")]
132 _type: Option<String>,
133 id: Option<String>,
134 call_id: Option<String>,
135 function: Option<FunctionCall>,
136 }
137 #[derive(Deserialize)]
138 struct Message {
139 content: Option<String>,
140 tool_calls: Option<Vec<ToolCallJson>>,
141 function_call: Option<FunctionCall>,
142 }
143 #[derive(Deserialize)]
144 struct Choice {
145 message: Message,
146 }
147 #[derive(Deserialize)]
148 struct ChatCompletion {
149 choices: Vec<Choice>,
150 }
151
152 let status = resp.status();
153 let body_text = resp.text().await.map_err(AgentError::from)?;
154 if !status.is_success() {
155 return Err(AgentError::Other(format!(
156 "HTTP {} error: {}",
157 status, body_text
158 )));
159 }
160 match serde_json::from_str::<ChatCompletion>(&body_text) {
161 Ok(body) => {
162 let mut text: Option<String> = None;
163 let mut tool_calls: Vec<ToolCall> = Vec::new();
164 if let Some(first) = body.choices.into_iter().next() {
165 text = first.message.content;
166 if let Some(tcs) = first.message.tool_calls {
167 for tc in tcs.into_iter() {
168 if let Some(func) = tc.function {
169 tool_calls.push(ToolCall {
170 id: tc.id,
171 name: func.name,
172 arguments: match func.arguments {
173 serde_json::Value::String(s) => s,
174 other => other.to_string(),
175 },
176 call_id: tc.call_id,
177 });
178 }
179 }
180 } else if let Some(func) = first.message.function_call {
181 tool_calls.push(ToolCall {
182 id: None,
183 name: func.name,
184 arguments: match func.arguments {
185 serde_json::Value::String(s) => s,
186 other => other.to_string(),
187 },
188 call_id: None,
189 });
190 }
191 }
192 Ok(ModelResponse {
193 id: None,
194 text,
195 tool_calls,
196 })
197 }
198 Err(_) => {
199 if let Ok(v) = serde_json::from_str::<serde_json::Value>(&body_text) {
200 let text = v
201 .get("choices")
202 .and_then(|c| c.as_array())
203 .and_then(|arr| arr.get(0))
204 .and_then(|c0| {
205 c0.get("message")
206 .and_then(|m| m.get("content"))
207 .and_then(|t| t.as_str())
208 .map(|s| s.to_string())
209 .or_else(|| {
210 c0.get("text")
211 .and_then(|t| t.as_str())
212 .map(|s| s.to_string())
213 })
214 });
215 return Ok(ModelResponse {
216 id: None,
217 text,
218 tool_calls: vec![],
219 });
220 }
221 Ok(ModelResponse {
222 id: None,
223 text: Some(body_text),
224 tool_calls: vec![],
225 })
226 }
227 }
228 }
229}