agentroot_core/llm/client.rs

//! HTTP client for external LLM services (vLLM, OpenAI, etc.)

use crate::config::LLMServiceConfig;
use crate::error::{AgentRootError, Result};
use crate::llm::{DocumentMetadata, MetadataContext};
use async_trait::async_trait;
use futures::stream;
use serde::{Deserialize, Serialize};
use std::sync::{atomic::AtomicU64, Arc};
use std::time::{Duration, Instant};

/// Trait for LLM service clients
#[async_trait]
pub trait LLMClient: Send + Sync {
    /// Generate chat completion
    async fn chat_completion(&self, messages: Vec<ChatMessage>) -> Result<String>;

    /// Generate embeddings for text
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embeddings for multiple texts
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;

    /// Get embedding dimensions
    fn embedding_dimensions(&self) -> usize;

    /// Get model name
    fn model_name(&self) -> &str;
}
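
// Hedged usage sketch: how a caller might drive any `LLMClient` implementation.
// `ask` is an illustrative helper for this sketch only, not part of the module's API.
#[allow(dead_code)]
async fn ask(client: &dyn LLMClient, question: &str) -> Result<String> {
    let messages = vec![
        ChatMessage::system("You are a concise technical assistant."),
        ChatMessage::user(question),
    ];
    client.chat_completion(messages).await
}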

/// Chat message for completion requests
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

impl ChatMessage {
    pub fn system(content: impl Into<String>) -> Self {
        Self {
            role: "system".to_string(),
            content: content.into(),
        }
    }

    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: "user".to_string(),
            content: content.into(),
        }
    }
}

/// API metrics for monitoring
#[derive(Debug, Default)]
pub struct APIMetrics {
    pub total_requests: AtomicU64,
    pub total_errors: AtomicU64,
    pub cache_hits: AtomicU64,
    pub cache_misses: AtomicU64,
    pub total_latency_ms: AtomicU64,
}

/// vLLM/OpenAI-compatible client
pub struct VLLMClient {
    http_client: reqwest::Client,
    config: LLMServiceConfig,
    embedding_dimensions: usize,
    cache: Arc<super::cache::LLMCache>,
    metrics: Arc<APIMetrics>,
}

impl VLLMClient {
    /// Create new vLLM client from configuration
    pub fn new(config: LLMServiceConfig) -> Result<Self> {
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(config.timeout_secs))
            .build()
            .map_err(AgentRootError::Http)?;

        // Use configured dimensions or default to 384
        let embedding_dimensions = config.embedding_dimensions.unwrap_or(384);

        // Enable caching by default (1 hour TTL)
        let cache = Arc::new(super::cache::LLMCache::new());

        // Initialize metrics
        let metrics = Arc::new(APIMetrics::default());

        Ok(Self {
            http_client,
            config,
            embedding_dimensions,
            cache,
            metrics,
        })
    }

    /// Create from environment variables
    pub fn from_env() -> Result<Self> {
        let config = LLMServiceConfig::default();
        Self::new(config)
    }

    /// Get current API metrics
    pub fn metrics(&self) -> MetricsSnapshot {
        use std::sync::atomic::Ordering;

        let total = self.metrics.total_requests.load(Ordering::Relaxed);
        let hits = self.metrics.cache_hits.load(Ordering::Relaxed);
        let misses = self.metrics.cache_misses.load(Ordering::Relaxed);

        MetricsSnapshot {
            total_requests: total,
            total_errors: self.metrics.total_errors.load(Ordering::Relaxed),
            cache_hits: hits,
            cache_misses: misses,
            cache_hit_rate: if total > 0 {
                hits as f64 / total as f64 * 100.0
            } else {
                0.0
            },
            avg_latency_ms: if total > 0 {
                self.metrics.total_latency_ms.load(Ordering::Relaxed) as f64 / total as f64
            } else {
                0.0
            },
        }
    }

    /// Embed texts with optimized batching
    ///
    /// Splits large inputs into fixed-size chunks and embeds them sequentially,
    /// reporting progress through the optional callback.
    pub async fn embed_batch_optimized<F>(
        &self,
        texts: &[String],
        batch_size: usize,
        progress_callback: Option<F>,
    ) -> Result<Vec<Vec<f32>>>
    where
        F: Fn(usize, usize) + Send + Sync,
    {
        const DEFAULT_BATCH_SIZE: usize = 32;
        let chunk_size = if batch_size > 0 {
            batch_size
        } else {
            DEFAULT_BATCH_SIZE
        };

        let total = texts.len();
        let mut all_results = Vec::with_capacity(total);

        for chunk in texts.chunks(chunk_size) {
            let chunk_results = self.embed_batch(chunk).await?;
            all_results.extend(chunk_results);

            if let Some(ref callback) = progress_callback {
                // Report the number of texts embedded so far out of the total.
                callback(all_results.len(), total);
            }
        }

        Ok(all_results)
    }

    /// Embed texts in parallel with multiple concurrent batches
    ///
    /// Processes several batches concurrently (via `futures::stream::buffer_unordered`)
    /// for higher throughput. Useful for embedding large document collections.
    pub async fn embed_batch_parallel(
        &self,
        texts: &[String],
        batch_size: usize,
        max_concurrent: usize,
    ) -> Result<Vec<Vec<f32>>> {
        use futures::stream::StreamExt;

        const DEFAULT_BATCH_SIZE: usize = 32;
        const DEFAULT_CONCURRENT: usize = 4;

        let chunk_size = if batch_size > 0 {
            batch_size
        } else {
            DEFAULT_BATCH_SIZE
        };
        let concurrent = if max_concurrent > 0 {
            max_concurrent
        } else {
            DEFAULT_CONCURRENT
        };

        let chunks: Vec<_> = texts.chunks(chunk_size).collect();
        let total_chunks = chunks.len();

        tracing::info!(
            "Embedding {} texts in {} batches ({} concurrent)",
            texts.len(),
            total_chunks,
            concurrent
        );

        let mut results: Vec<_> = stream::iter(chunks)
            .enumerate()
            .map(|(idx, chunk)| async move {
                tracing::debug!("Processing batch {}/{}", idx + 1, total_chunks);
                let result = self.embed_batch(chunk).await;
                (idx, result)
            })
            .buffer_unordered(concurrent)
            .collect()
            .await;

        // Restore the original order, then flatten the per-batch results
        results.sort_by_key(|(idx, _)| *idx);

        let mut all_embeddings = Vec::with_capacity(texts.len());
        for (_, result) in results {
            all_embeddings.extend(result?);
        }

        Ok(all_embeddings)
    }
}
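
// A minimal usage sketch of the batching helpers. Assumptions: a vLLM-compatible
// service is reachable at the URL produced by `LLMServiceConfig::default()`, and this
// free function exists only for illustration; it is not part of the crate's public API.
#[allow(dead_code)]
async fn example_bulk_embed(texts: &[String]) -> Result<Vec<Vec<f32>>> {
    let client = VLLMClient::from_env()?;

    // Sequential chunks with progress reporting...
    let _with_progress = client
        .embed_batch_optimized(
            texts,
            32,
            Some(|done: usize, total: usize| tracing::debug!("embedded {}/{}", done, total)),
        )
        .await?;

    // ...or several chunks in flight at once for large collections.
    client.embed_batch_parallel(texts, 32, 4).await
}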

/// Snapshot of API metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsSnapshot {
    pub total_requests: u64,
    pub total_errors: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    /// Cache hit rate as a percentage (0.0-100.0)
    pub cache_hit_rate: f64,
    /// Average request latency in milliseconds
    pub avg_latency_ms: f64,
}
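
// Hedged sketch: surfacing a metrics snapshot through `tracing`. The helper name is
// illustrative only; nothing in the crate is assumed to call it.
#[allow(dead_code)]
fn log_metrics(client: &VLLMClient) {
    let snapshot = client.metrics();
    tracing::info!(
        "llm client: {} requests, {} errors, {:.1}% cache hits, {:.1} ms avg latency",
        snapshot.total_requests,
        snapshot.total_errors,
        snapshot.cache_hit_rate,
        snapshot.avg_latency_ms
    );
}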

#[async_trait]
impl LLMClient for VLLMClient {
    async fn chat_completion(&self, messages: Vec<ChatMessage>) -> Result<String> {
        use std::sync::atomic::Ordering;

        let start = Instant::now();
        self.metrics.total_requests.fetch_add(1, Ordering::Relaxed);

        // Check cache first
        let messages_json = serde_json::to_string(&messages).unwrap_or_default();
        let cache_key = super::cache::chat_cache_key(&self.config.model, &messages_json);

        if let Some(cached) = self.cache.get(&cache_key) {
            tracing::debug!("Cache hit for chat completion");
            self.metrics.cache_hits.fetch_add(1, Ordering::Relaxed);
            return Ok(cached);
        }

        self.metrics.cache_misses.fetch_add(1, Ordering::Relaxed);

        #[derive(Serialize)]
        struct ChatRequest {
            model: String,
            messages: Vec<ChatMessage>,
            temperature: f32,
            max_tokens: u32,
        }

        #[derive(Deserialize)]
        struct ChatResponse {
            choices: Vec<ChatChoice>,
        }

        #[derive(Deserialize)]
        struct ChatChoice {
            message: ChatMessage,
        }

        let request = ChatRequest {
            model: self.config.model.clone(),
            messages,
            temperature: 0.7,
            max_tokens: 512,
        };

        let url = format!("{}/v1/chat/completions", self.config.url);

        let mut req = self.http_client.post(&url).json(&request);

        if let Some(ref api_key) = self.config.api_key {
            req = req.header("Authorization", format!("Bearer {}", api_key));
        }

        let response = req.send().await.map_err(|e| {
            self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
            AgentRootError::Http(e)
        })?;

        if !response.status().is_success() {
            self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(AgentRootError::ExternalError(format!(
                "LLM service error (HTTP {}): {}",
                status, body
            )));
        }

        let chat_response: ChatResponse = response.json().await.map_err(|e| {
            self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
            AgentRootError::Http(e)
        })?;

        let content = chat_response
            .choices
            .first()
            .ok_or_else(|| {
                self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
                AgentRootError::Llm("No response from LLM".to_string())
            })?
            .message
            .content
            .clone();

        // Cache the response
        let _ = self.cache.set(cache_key, content.clone());

        // Track latency
        let elapsed = start.elapsed().as_millis() as u64;
        self.metrics
            .total_latency_ms
            .fetch_add(elapsed, Ordering::Relaxed);

        Ok(content)
    }

    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let results = self.embed_batch(&[text.to_string()]).await?;
        results
            .into_iter()
            .next()
            .ok_or_else(|| AgentRootError::Llm("No embedding returned".to_string()))
    }

    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        use std::sync::atomic::Ordering;

        let start = Instant::now();
        self.metrics.total_requests.fetch_add(1, Ordering::Relaxed);

        // Check cache for each text
        let mut results = Vec::with_capacity(texts.len());
        let mut uncached_texts = Vec::new();
        let mut uncached_indices = Vec::new();

        for (i, text) in texts.iter().enumerate() {
            let cache_key = super::cache::embedding_cache_key(&self.config.embedding_model, text);
            if let Some(cached) = self.cache.get(&cache_key) {
                // Parse cached embedding
                if let Ok(embedding) = serde_json::from_str::<Vec<f32>>(&cached) {
                    results.push(Some(embedding));
                    self.metrics.cache_hits.fetch_add(1, Ordering::Relaxed);
                    continue;
                }
            }
            self.metrics.cache_misses.fetch_add(1, Ordering::Relaxed);
            results.push(None);
            uncached_texts.push(text.clone());
            uncached_indices.push(i);
        }

        // If all cached, return early
        if uncached_texts.is_empty() {
            tracing::debug!("All {} embeddings from cache", texts.len());
            return Ok(results.into_iter().map(|r| r.unwrap()).collect());
        }

        tracing::debug!(
            "Embedding batch: {} cached, {} to fetch",
            texts.len() - uncached_texts.len(),
            uncached_texts.len()
        );

        #[derive(Serialize)]
        struct EmbedRequest {
            model: String,
            input: Vec<String>,
        }

        #[derive(Deserialize)]
        struct EmbedResponse {
            data: Vec<EmbedData>,
        }

        #[derive(Deserialize)]
        struct EmbedData {
            embedding: Vec<f32>,
        }

        let request = EmbedRequest {
            model: self.config.embedding_model.clone(),
            input: uncached_texts.clone(),
        };

        let url = format!("{}/v1/embeddings", self.config.embeddings_url());

        let mut req = self.http_client.post(&url).json(&request);

        if let Some(ref api_key) = self.config.api_key {
            req = req.header("Authorization", format!("Bearer {}", api_key));
        }

        let response = req.send().await.map_err(|e| {
            self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
            AgentRootError::Http(e)
        })?;

        if !response.status().is_success() {
            self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(AgentRootError::ExternalError(format!(
                "Embedding service error (HTTP {}): {}",
                status, body
            )));
        }

        let embed_response: EmbedResponse = response.json().await.map_err(|e| {
            self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
            AgentRootError::Http(e)
        })?;

        // The service must return exactly one embedding per uncached input;
        // bail out instead of panicking on a mismatched response.
        if embed_response.data.len() != uncached_texts.len() {
            self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
            return Err(AgentRootError::Llm(format!(
                "Embedding service returned {} embeddings for {} inputs",
                embed_response.data.len(),
                uncached_texts.len()
            )));
        }

        // Fill in uncached results and cache them
        for (i, embedding) in embed_response.data.into_iter().enumerate() {
            let original_idx = uncached_indices[i];
            results[original_idx] = Some(embedding.embedding.clone());

            // Cache the embedding
            let cache_key =
                super::cache::embedding_cache_key(&self.config.embedding_model, &uncached_texts[i]);
            if let Ok(json) = serde_json::to_string(&embedding.embedding) {
                let _ = self.cache.set(cache_key, json);
            }
        }

        // Track latency
        let elapsed = start.elapsed().as_millis() as u64;
        self.metrics
            .total_latency_ms
            .fetch_add(elapsed, Ordering::Relaxed);

        Ok(results.into_iter().map(|r| r.unwrap()).collect())
    }

    fn embedding_dimensions(&self) -> usize {
        self.embedding_dimensions
    }

    fn model_name(&self) -> &str {
        &self.config.model
    }
}
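
// Sketch: a convenience constructor returning a trait object, so callers can stay
// generic over the backing service. The function name is an assumption for
// illustration; it is not referenced elsewhere in this module.
#[allow(dead_code)]
fn boxed_client(config: LLMServiceConfig) -> Result<Box<dyn LLMClient>> {
    Ok(Box::new(VLLMClient::new(config)?))
}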

/// Helper to generate metadata using LLM client
pub async fn generate_metadata_with_llm(
    client: &dyn LLMClient,
    content: &str,
    context: &MetadataContext,
) -> Result<DocumentMetadata> {
    let prompt = build_metadata_prompt(content, context);

    let messages = vec![
        ChatMessage::system(
            "You are a metadata generator. Extract structured metadata from documents. \
             Respond ONLY with valid JSON matching the schema.",
        ),
        ChatMessage::user(prompt),
    ];

    let response = client.chat_completion(messages).await?;

    // Parse JSON response
    parse_metadata_response(&response)
}
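
// Hedged sketch: one way a caller might combine metadata generation with an embedding
// of the raw content. `enrich_document` and its signature are assumptions for
// illustration; the crate's real indexing path may differ.
#[allow(dead_code)]
async fn enrich_document(
    client: &VLLMClient,
    content: &str,
    context: &MetadataContext,
) -> Result<(DocumentMetadata, Vec<f32>)> {
    // Structured metadata via chat completion, vector via the embeddings endpoint.
    let metadata = generate_metadata_with_llm(client, content, context).await?;
    let embedding = client.embed(content).await?;
    Ok((metadata, embedding))
}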

fn build_metadata_prompt(content: &str, context: &MetadataContext) -> String {
    // Truncate content if too long (~2000 tokens, roughly 8000 bytes), backing up to a
    // char boundary so slicing can't panic on multi-byte UTF-8.
    const MAX_CONTENT_BYTES: usize = 8000;
    let truncated = if content.len() > MAX_CONTENT_BYTES {
        let mut end = MAX_CONTENT_BYTES;
        while !content.is_char_boundary(end) {
            end -= 1;
        }
        &content[..end]
    } else {
        content
    };

    format!(
        r#"Generate metadata for this document:

Source type: {}
Language: {}
Collection: {}

Content:
{}

Output JSON with these fields:
- summary: 100-200 word summary
- semantic_title: improved title
- keywords: 5-10 keywords (array)
- category: document type
- intent: purpose description
- concepts: related concepts (array)
- difficulty: beginner/intermediate/advanced
- suggested_queries: search queries (array)

JSON:"#,
        context.source_type,
        context.language.as_deref().unwrap_or("unknown"),
        context.collection_name,
        truncated
    )
}

fn parse_metadata_response(response: &str) -> Result<DocumentMetadata> {
    // Extract JSON from response (handle markdown code blocks)
    let json_str = if let Some(start) = response.find('{') {
        if let Some(end) = response.rfind('}') {
            &response[start..=end]
        } else {
            response
        }
    } else {
        response
    };

    serde_json::from_str(json_str)
        .map_err(|e| AgentRootError::Llm(format!("Failed to parse metadata JSON: {}", e)))
}
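
#[cfg(test)]
mod client_sketch_tests {
    use super::*;

    // A minimal sanity check for the ChatMessage constructors; a sketch, not full
    // coverage of the client (which would need a live or mocked HTTP endpoint).
    #[test]
    fn chat_message_constructors_set_roles() {
        let sys = ChatMessage::system("be terse");
        let usr = ChatMessage::user("hello");
        assert_eq!(sys.role, "system");
        assert_eq!(usr.role, "user");
        assert_eq!(usr.content, "hello");
    }
}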