elizaos_plugin_copilot_proxy/
client.rs1use reqwest::{header::CONTENT_TYPE, Client};
4use std::time::Duration;
5use tracing::debug;
6
7use crate::config::CopilotProxyConfig;
8use crate::error::{CopilotProxyError, Result};
9use crate::types::{
10 ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ModelsResponse,
11 TextGenerationParams, TextGenerationResult,
12};
13
14pub struct CopilotProxyClient {
16 client: Client,
17 config: CopilotProxyConfig,
18}
19
20impl CopilotProxyClient {
21 pub fn new(config: CopilotProxyConfig) -> Result<Self> {
23 config.validate()?;
24
25 let client = Client::builder()
26 .default_headers({
27 let mut headers = reqwest::header::HeaderMap::new();
28 headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
29 headers
30 })
31 .timeout(Duration::from_secs(config.timeout_secs))
32 .build()?;
33
34 Ok(Self { client, config })
35 }
36
37 pub fn from_env() -> Result<Self> {
39 Self::new(CopilotProxyConfig::from_env())
40 }
41
42 pub fn base_url(&self) -> &str {
44 &self.config.base_url
45 }
46
47 pub fn config(&self) -> &CopilotProxyConfig {
49 &self.config
50 }
51
52 fn url(&self, endpoint: &str) -> String {
54 format!("{}{}", self.config.base_url, endpoint)
55 }
56
57 async fn check_response(
59 &self,
60 response: reqwest::Response,
61 ) -> Result<reqwest::Response> {
62 if response.status().is_success() {
63 return Ok(response);
64 }
65
66 let status = response.status().as_u16();
67 let message = response
68 .text()
69 .await
70 .unwrap_or_else(|_| "Unknown error".to_string());
71
72 let message = serde_json::from_str::<serde_json::Value>(&message)
74 .ok()
75 .and_then(|v| v["error"]["message"].as_str().map(String::from))
76 .unwrap_or(message);
77
78 Err(CopilotProxyError::ApiError { status, message })
79 }
80
81 pub async fn list_models(&self) -> Result<ModelsResponse> {
83 debug!("Listing Copilot Proxy models");
84 let response = self.client.get(self.url("/models")).send().await?;
85 let response = self.check_response(response).await?;
86 Ok(response.json().await?)
87 }
88
89 pub async fn health_check(&self) -> bool {
91 match self.list_models().await {
92 Ok(_) => true,
93 Err(_) => false,
94 }
95 }
96
97 pub async fn create_chat_completion(
99 &self,
100 request: &ChatCompletionRequest,
101 ) -> Result<ChatCompletionResponse> {
102 debug!("Creating chat completion with model: {}", request.model);
103
104 let response = self
105 .client
106 .post(self.url("/chat/completions"))
107 .json(request)
108 .send()
109 .await?;
110 let response = self.check_response(response).await?;
111
112 Ok(response.json().await?)
113 }
114
115 pub async fn generate_text(&self, params: &TextGenerationParams) -> Result<TextGenerationResult> {
117 let model = params
118 .model
119 .as_deref()
120 .unwrap_or(&self.config.large_model);
121 debug!("Generating text with model: {}", model);
122
123 let mut messages = Vec::new();
124
125 if let Some(system) = ¶ms.system {
126 messages.push(ChatMessage::system(system));
127 }
128
129 messages.push(ChatMessage::user(¶ms.prompt));
130
131 let mut request = ChatCompletionRequest::new(model, messages);
132
133 if let Some(max_tokens) = params.max_tokens {
134 request = request.max_tokens(max_tokens);
135 } else {
136 request = request.max_tokens(self.config.max_tokens);
137 }
138
139 if let Some(temp) = params.temperature {
140 request = request.temperature(temp);
141 }
142
143 if let Some(fp) = params.frequency_penalty {
144 request = request.frequency_penalty(fp);
145 }
146
147 if let Some(pp) = params.presence_penalty {
148 request = request.presence_penalty(pp);
149 }
150
151 if let Some(stop) = ¶ms.stop {
152 request = request.stop(stop.clone());
153 }
154
155 let response = self.create_chat_completion(&request).await?;
156
157 let text = response
158 .choices
159 .first()
160 .and_then(|c| c.message.content.clone())
161 .ok_or(CopilotProxyError::EmptyResponse)?;
162
163 Ok(TextGenerationResult {
164 text,
165 usage: response.usage,
166 })
167 }
168
169 pub async fn generate_text_small(&self, prompt: &str) -> Result<String> {
171 let params = TextGenerationParams::new(prompt)
172 .model(&self.config.small_model);
173 let result = self.generate_text(¶ms).await?;
174 Ok(result.text)
175 }
176
177 pub async fn generate_text_large(&self, prompt: &str) -> Result<String> {
179 let params = TextGenerationParams::new(prompt)
180 .model(&self.config.large_model);
181 let result = self.generate_text(¶ms).await?;
182 Ok(result.text)
183 }
184
185 pub async fn generate_object(
187 &self,
188 prompt: &str,
189 model: Option<&str>,
190 ) -> Result<serde_json::Value> {
191 let json_prompt = format!(
192 "{}\nPlease respond with valid JSON only, without any explanations, markdown formatting, or additional text.",
193 prompt
194 );
195
196 let params = TextGenerationParams::new(json_prompt)
197 .model(model.unwrap_or(&self.config.small_model))
198 .system("You must respond with valid JSON only. No markdown, no code blocks, no explanation text.")
199 .temperature(0.2);
200
201 let result = self.generate_text(¶ms).await?;
202 extract_json(&result.text)
203 }
204}
205
206fn extract_json(text: &str) -> Result<serde_json::Value> {
208 if let Ok(value) = serde_json::from_str(text) {
210 return Ok(value);
211 }
212
213 let json_block_re = regex::Regex::new(r"```json\s*([\s\S]*?)\s*```").ok();
215 if let Some(re) = &json_block_re {
216 if let Some(caps) = re.captures(text) {
217 if let Some(content) = caps.get(1) {
218 if let Ok(value) = serde_json::from_str(content.as_str().trim()) {
219 return Ok(value);
220 }
221 }
222 }
223 }
224
225 let any_block_re = regex::Regex::new(r"```(?:\w*)\s*([\s\S]*?)\s*```").ok();
227 if let Some(re) = &any_block_re {
228 if let Some(caps) = re.captures(text) {
229 if let Some(content) = caps.get(1) {
230 let trimmed = content.as_str().trim();
231 if trimmed.starts_with('{') && trimmed.ends_with('}') {
232 if let Ok(value) = serde_json::from_str(trimmed) {
233 return Ok(value);
234 }
235 }
236 }
237 }
238 }
239
240 if let Some(json_obj) = find_json_object(text) {
242 if let Ok(value) = serde_json::from_str(&json_obj) {
243 return Ok(value);
244 }
245 }
246
247 Err(CopilotProxyError::JsonExtractionError(
248 "Could not extract valid JSON from response".to_string(),
249 ))
250}
251
/// Finds the most likely JSON object substring in `text` by brace matching.
///
/// If the trimmed text already starts with `{` and ends with `}`, it is
/// returned whole. Otherwise the text is scanned for balanced top-level
/// `{ … }` spans and the longest one is returned (the earliest wins a tie).
/// Returns `None` when no balanced span exists.
///
/// The result is only a candidate: braces inside string literals are not
/// treated specially, so callers must still validate it as JSON.
fn find_json_object(text: &str) -> Option<String> {
    let trimmed = text.trim();
    if trimmed.starts_with('{') && trimmed.ends_with('}') {
        return Some(trimmed.to_string());
    }

    let mut best: Option<&str> = None;
    let mut depth: usize = 0;
    let mut start: Option<usize> = None;

    // `char_indices` yields byte offsets, which is what string slicing needs;
    // character counts diverge from byte offsets after any multi-byte char.
    for (idx, ch) in text.char_indices() {
        match ch {
            '{' => {
                if depth == 0 {
                    start = Some(idx);
                }
                depth += 1;
            }
            // An unmatched `}` (depth already 0) is ignored so it cannot push
            // the depth negative and hide every object that follows it.
            '}' if depth > 0 => {
                depth -= 1;
                if depth == 0 {
                    if let Some(begin) = start.take() {
                        // `}` is one byte, so `idx + 1` is a char boundary.
                        let candidate = &text[begin..idx + 1];
                        if best.map_or(true, |b| candidate.len() > b.len()) {
                            best = Some(candidate);
                        }
                    }
                }
            }
            _ => {}
        }
    }

    best.map(str::to_owned)
}
284
/// Unit tests for JSON extraction, client construction and error formatting.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_json_direct() {
        let json = r#"{"message": "hello"}"#;
        let result = extract_json(json).unwrap();
        assert_eq!(result["message"], "hello");
    }

    #[test]
    fn test_extract_json_code_block() {
        let text = r#"Here is the response:
```json
{"message": "hello"}
```"#;
        let result = extract_json(text).unwrap();
        assert_eq!(result["message"], "hello");
    }

    #[test]
    fn test_extract_json_embedded() {
        let text = r#"The answer is {"message": "hello"} as you can see."#;
        let result = extract_json(text).unwrap();
        assert_eq!(result["message"], "hello");
    }

    #[test]
    fn test_extract_json_fails_for_plain_text() {
        let text = "This is not JSON at all.";
        let result = extract_json(text);
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_json_any_code_block() {
        let text = "Result:\n```\n{\"key\": 42}\n```";
        let result = extract_json(text).unwrap();
        assert_eq!(result["key"], 42);
    }

    #[test]
    fn test_extract_json_nested_objects() {
        let text = r#"{"outer": {"inner": "value"}}"#;
        let result = extract_json(text).unwrap();
        assert_eq!(result["outer"]["inner"], "value");
    }

    #[test]
    fn test_find_json_object_picks_largest() {
        let text = r#"small: {"a": 1} and large: {"b": 2, "c": 3}"#;
        let found = find_json_object(text).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&found).unwrap();
        // The longer object must win; accepting either one would let a
        // regression in the "largest" selection pass unnoticed.
        assert_eq!(parsed["b"], 2);
        assert_eq!(parsed["c"], 3);
        assert!(parsed.get("a").is_none());
    }

    #[test]
    fn test_client_url_construction() {
        let config = CopilotProxyConfig::new().base_url("http://localhost:9999/v1");
        let client = CopilotProxyClient::new(config).unwrap();
        assert_eq!(client.base_url(), "http://localhost:9999/v1");
    }

    #[test]
    fn test_client_creation_with_empty_base_url_fails() {
        let config = CopilotProxyConfig {
            base_url: "".to_string(),
            ..CopilotProxyConfig::new()
        };
        let result = CopilotProxyClient::new(config);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_health_check_unreachable_returns_false() {
        let config = CopilotProxyConfig::new()
            .base_url("http://127.0.0.1:1")
            .timeout_secs(1);
        let client = CopilotProxyClient::new(config).unwrap();
        assert!(!client.health_check().await);
    }

    #[test]
    fn test_check_response_builds_api_error() {
        // Exercises only the `Display` output of `ApiError`; no HTTP response
        // is involved.
        let err = CopilotProxyError::ApiError {
            status: 429,
            message: "Rate limited".to_string(),
        };
        let msg = format!("{}", err);
        assert!(msg.contains("429"));
        assert!(msg.contains("Rate limited"));
    }

    #[test]
    fn test_empty_response_error() {
        let err = CopilotProxyError::EmptyResponse;
        let msg = format!("{}", err);
        assert!(msg.to_lowercase().contains("empty"));
    }
}