Skip to main content

elizaos_plugin_copilot_proxy/
client.rs

1//! HTTP client for the Copilot Proxy server.
2
3use reqwest::{header::CONTENT_TYPE, Client};
4use std::time::Duration;
5use tracing::debug;
6
7use crate::config::CopilotProxyConfig;
8use crate::error::{CopilotProxyError, Result};
9use crate::types::{
10    ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ModelsResponse,
11    TextGenerationParams, TextGenerationResult,
12};
13
14/// HTTP client for interacting with the Copilot Proxy server.
15pub struct CopilotProxyClient {
16    client: Client,
17    config: CopilotProxyConfig,
18}
19
20impl CopilotProxyClient {
21    /// Create a new Copilot Proxy client.
22    pub fn new(config: CopilotProxyConfig) -> Result<Self> {
23        config.validate()?;
24
25        let client = Client::builder()
26            .default_headers({
27                let mut headers = reqwest::header::HeaderMap::new();
28                headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
29                headers
30            })
31            .timeout(Duration::from_secs(config.timeout_secs))
32            .build()?;
33
34        Ok(Self { client, config })
35    }
36
37    /// Create a client from environment variables.
38    pub fn from_env() -> Result<Self> {
39        Self::new(CopilotProxyConfig::from_env())
40    }
41
42    /// Get the base URL.
43    pub fn base_url(&self) -> &str {
44        &self.config.base_url
45    }
46
47    /// Get the configuration.
48    pub fn config(&self) -> &CopilotProxyConfig {
49        &self.config
50    }
51
52    /// Build a URL for an endpoint.
53    fn url(&self, endpoint: &str) -> String {
54        format!("{}{}", self.config.base_url, endpoint)
55    }
56
57    /// Check the response for errors.
58    async fn check_response(
59        &self,
60        response: reqwest::Response,
61    ) -> Result<reqwest::Response> {
62        if response.status().is_success() {
63            return Ok(response);
64        }
65
66        let status = response.status().as_u16();
67        let message = response
68            .text()
69            .await
70            .unwrap_or_else(|_| "Unknown error".to_string());
71
72        // Try to parse as JSON error
73        let message = serde_json::from_str::<serde_json::Value>(&message)
74            .ok()
75            .and_then(|v| v["error"]["message"].as_str().map(String::from))
76            .unwrap_or(message);
77
78        Err(CopilotProxyError::ApiError { status, message })
79    }
80
81    /// List available models.
82    pub async fn list_models(&self) -> Result<ModelsResponse> {
83        debug!("Listing Copilot Proxy models");
84        let response = self.client.get(self.url("/models")).send().await?;
85        let response = self.check_response(response).await?;
86        Ok(response.json().await?)
87    }
88
89    /// Check if the proxy server is available.
90    pub async fn health_check(&self) -> bool {
91        match self.list_models().await {
92            Ok(_) => true,
93            Err(_) => false,
94        }
95    }
96
97    /// Create a chat completion.
98    pub async fn create_chat_completion(
99        &self,
100        request: &ChatCompletionRequest,
101    ) -> Result<ChatCompletionResponse> {
102        debug!("Creating chat completion with model: {}", request.model);
103
104        let response = self
105            .client
106            .post(self.url("/chat/completions"))
107            .json(request)
108            .send()
109            .await?;
110        let response = self.check_response(response).await?;
111
112        Ok(response.json().await?)
113    }
114
115    /// Generate text using the chat completion API.
116    pub async fn generate_text(&self, params: &TextGenerationParams) -> Result<TextGenerationResult> {
117        let model = params
118            .model
119            .as_deref()
120            .unwrap_or(&self.config.large_model);
121        debug!("Generating text with model: {}", model);
122
123        let mut messages = Vec::new();
124
125        if let Some(system) = &params.system {
126            messages.push(ChatMessage::system(system));
127        }
128
129        messages.push(ChatMessage::user(&params.prompt));
130
131        let mut request = ChatCompletionRequest::new(model, messages);
132
133        if let Some(max_tokens) = params.max_tokens {
134            request = request.max_tokens(max_tokens);
135        } else {
136            request = request.max_tokens(self.config.max_tokens);
137        }
138
139        if let Some(temp) = params.temperature {
140            request = request.temperature(temp);
141        }
142
143        if let Some(fp) = params.frequency_penalty {
144            request = request.frequency_penalty(fp);
145        }
146
147        if let Some(pp) = params.presence_penalty {
148            request = request.presence_penalty(pp);
149        }
150
151        if let Some(stop) = &params.stop {
152            request = request.stop(stop.clone());
153        }
154
155        let response = self.create_chat_completion(&request).await?;
156
157        let text = response
158            .choices
159            .first()
160            .and_then(|c| c.message.content.clone())
161            .ok_or(CopilotProxyError::EmptyResponse)?;
162
163        Ok(TextGenerationResult {
164            text,
165            usage: response.usage,
166        })
167    }
168
169    /// Generate text using the small model.
170    pub async fn generate_text_small(&self, prompt: &str) -> Result<String> {
171        let params = TextGenerationParams::new(prompt)
172            .model(&self.config.small_model);
173        let result = self.generate_text(&params).await?;
174        Ok(result.text)
175    }
176
177    /// Generate text using the large model.
178    pub async fn generate_text_large(&self, prompt: &str) -> Result<String> {
179        let params = TextGenerationParams::new(prompt)
180            .model(&self.config.large_model);
181        let result = self.generate_text(&params).await?;
182        Ok(result.text)
183    }
184
185    /// Generate a JSON object using the specified model.
186    pub async fn generate_object(
187        &self,
188        prompt: &str,
189        model: Option<&str>,
190    ) -> Result<serde_json::Value> {
191        let json_prompt = format!(
192            "{}\nPlease respond with valid JSON only, without any explanations, markdown formatting, or additional text.",
193            prompt
194        );
195
196        let params = TextGenerationParams::new(json_prompt)
197            .model(model.unwrap_or(&self.config.small_model))
198            .system("You must respond with valid JSON only. No markdown, no code blocks, no explanation text.")
199            .temperature(0.2);
200
201        let result = self.generate_text(&params).await?;
202        extract_json(&result.text)
203    }
204}
205
206/// Extract JSON from a text response.
207fn extract_json(text: &str) -> Result<serde_json::Value> {
208    // Try direct parse first
209    if let Ok(value) = serde_json::from_str(text) {
210        return Ok(value);
211    }
212
213    // Try extracting from JSON code block
214    let json_block_re = regex::Regex::new(r"```json\s*([\s\S]*?)\s*```").ok();
215    if let Some(re) = &json_block_re {
216        if let Some(caps) = re.captures(text) {
217            if let Some(content) = caps.get(1) {
218                if let Ok(value) = serde_json::from_str(content.as_str().trim()) {
219                    return Ok(value);
220                }
221            }
222        }
223    }
224
225    // Try extracting from any code block
226    let any_block_re = regex::Regex::new(r"```(?:\w*)\s*([\s\S]*?)\s*```").ok();
227    if let Some(re) = &any_block_re {
228        if let Some(caps) = re.captures(text) {
229            if let Some(content) = caps.get(1) {
230                let trimmed = content.as_str().trim();
231                if trimmed.starts_with('{') && trimmed.ends_with('}') {
232                    if let Ok(value) = serde_json::from_str(trimmed) {
233                        return Ok(value);
234                    }
235                }
236            }
237        }
238    }
239
240    // Try finding JSON object in text
241    if let Some(json_obj) = find_json_object(text) {
242        if let Ok(value) = serde_json::from_str(&json_obj) {
243            return Ok(value);
244        }
245    }
246
247    Err(CopilotProxyError::JsonExtractionError(
248        "Could not extract valid JSON from response".to_string(),
249    ))
250}
251
/// Find the longest brace-balanced `{ ... }` span in `text`.
///
/// If the trimmed text already starts with `{` and ends with `}`, it is returned
/// as-is. Otherwise the text is scanned for top-level balanced spans and the
/// longest one (by byte length; the first wins on ties) is returned.
///
/// While inside a span, braces within double-quoted strings (honouring `\`
/// escapes) are ignored. A stray `}` that closes nothing is skipped instead of
/// corrupting the nesting count.
///
/// The result is only a candidate: callers must still parse it as JSON.
fn find_json_object(text: &str) -> Option<String> {
    let trimmed = text.trim();
    if trimmed.starts_with('{') && trimmed.ends_with('}') {
        return Some(trimmed.to_string());
    }

    let mut best: Option<&str> = None;
    let mut depth: usize = 0;
    let mut start = 0;
    let mut in_string = false;
    let mut escaped = false;

    // `char_indices` yields byte offsets, so the slice below always lands on
    // UTF-8 boundaries even when the surrounding prose is non-ASCII.
    for (i, ch) in text.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }

        match ch {
            // Quotes only matter inside a span; prose quotes are ignored.
            '"' if depth > 0 => in_string = true,
            '{' => {
                if depth == 0 {
                    start = i;
                }
                depth += 1;
            }
            '}' if depth > 0 => {
                depth -= 1;
                if depth == 0 {
                    // `}` is a single byte, so the inclusive range is exact.
                    let candidate = &text[start..=i];
                    if best.map_or(0, str::len) < candidate.len() {
                        best = Some(candidate);
                    }
                }
            }
            _ => {}
        }
    }

    best.map(str::to_string)
}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_json_direct() {
        let json = r#"{"message": "hello"}"#;
        let result = extract_json(json).unwrap();
        assert_eq!(result["message"], "hello");
    }

    #[test]
    fn test_extract_json_direct_array() {
        // A bare top-level array parses on the first (direct) attempt.
        let result = extract_json("[1, 2, 3]").unwrap();
        assert_eq!(result[1], 2);
    }

    #[test]
    fn test_extract_json_code_block() {
        let text = r#"Here is the response:
```json
{"message": "hello"}
```"#;
        let result = extract_json(text).unwrap();
        assert_eq!(result["message"], "hello");
    }

    #[test]
    fn test_extract_json_falls_back_when_fenced_block_is_invalid() {
        // The fenced block is not JSON, so extraction must fall through to
        // scanning the surrounding text for a balanced object.
        let text = "```json\nnot json\n```\nActual: {\"ok\": true}";
        let result = extract_json(text).unwrap();
        assert_eq!(result["ok"], true);
    }

    #[test]
    fn test_extract_json_embedded() {
        let text = r#"The answer is {"message": "hello"} as you can see."#;
        let result = extract_json(text).unwrap();
        assert_eq!(result["message"], "hello");
    }

    #[test]
    fn test_extract_json_fails_for_plain_text() {
        let text = "This is not JSON at all.";
        let result = extract_json(text);
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_json_any_code_block() {
        let text = "Result:\n```\n{\"key\": 42}\n```";
        let result = extract_json(text).unwrap();
        assert_eq!(result["key"], 42);
    }

    #[test]
    fn test_extract_json_nested_objects() {
        let text = r#"{"outer": {"inner": "value"}}"#;
        let result = extract_json(text).unwrap();
        assert_eq!(result["outer"]["inner"], "value");
    }

    #[test]
    fn test_find_json_object_picks_largest() {
        let text = r#"small: {"a": 1} and large: {"b": 2, "c": 3}"#;
        let found = find_json_object(text).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&found).unwrap();
        // Only the larger object may be returned.
        assert_eq!(parsed["b"], 2);
        assert_eq!(parsed["c"], 3);
        assert!(parsed.get("a").is_none());
    }

    #[test]
    fn test_client_url_construction() {
        let config = CopilotProxyConfig::new().base_url("http://localhost:9999/v1");
        let client = CopilotProxyClient::new(config).unwrap();
        assert_eq!(client.base_url(), "http://localhost:9999/v1");
        assert_eq!(client.url("/models"), "http://localhost:9999/v1/models");
    }

    #[test]
    fn test_client_creation_with_empty_base_url_fails() {
        let config = CopilotProxyConfig {
            base_url: "".to_string(),
            ..CopilotProxyConfig::new()
        };
        let result = CopilotProxyClient::new(config);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_health_check_unreachable_returns_false() {
        let config = CopilotProxyConfig::new()
            .base_url("http://127.0.0.1:1")
            .timeout_secs(1);
        let client = CopilotProxyClient::new(config).unwrap();
        assert!(!client.health_check().await);
    }

    #[test]
    fn test_api_error_display_includes_status_and_message() {
        let err = CopilotProxyError::ApiError {
            status: 429,
            message: "Rate limited".to_string(),
        };
        let msg = format!("{}", err);
        assert!(msg.contains("429"));
        assert!(msg.contains("Rate limited"));
    }

    #[test]
    fn test_empty_response_error() {
        let err = CopilotProxyError::EmptyResponse;
        let msg = format!("{}", err);
        assert!(msg.to_lowercase().contains("empty"));
    }
}