use crate::config::LLMServiceConfig;
use crate::error::{AgentRootError, Result};
use crate::llm::{DocumentMetadata, MetadataContext};
use async_trait::async_trait;
use futures::stream;
use serde::{Deserialize, Serialize};
use std::sync::{atomic::AtomicU64, Arc};
use std::time::{Duration, Instant};

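/// Abstraction over an LLM backend: chat completions plus text embeddings.
///
/// A minimal usage sketch (`client` may be any implementor, e.g. a
/// [`VLLMClient`]; the messages are illustrative):
///
/// ```ignore
/// let reply = client
///     .chat_completion(vec![
///         ChatMessage::system("You are a terse assistant."),
///         ChatMessage::user("Say hello."),
///     ])
///     .await?;
/// ```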
#[async_trait]
pub trait LLMClient: Send + Sync {
    /// Run a chat completion over `messages` and return the reply text.
    async fn chat_completion(&self, messages: Vec<ChatMessage>) -> Result<String>;

    /// Embed a single text into a vector.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embed a batch of texts, returning one vector per input, in input order.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;

    /// Dimensionality of the vectors produced by the embedding endpoints.
    fn embedding_dimensions(&self) -> usize;

    /// Name of the underlying chat model.
    fn model_name(&self) -> &str;
}

/// A single role-tagged message in an OpenAI-style chat conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

impl ChatMessage {
    pub fn system(content: impl Into<String>) -> Self {
        Self {
            role: "system".to_string(),
            content: content.into(),
        }
    }

    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: "user".to_string(),
            content: content.into(),
        }
    }
}

/// Lock-free counters for request volume, errors, cache behaviour, and
/// cumulative latency. These are plain monotonic counters, so `Relaxed`
/// atomic ordering is sufficient wherever they are read or bumped.
#[derive(Debug, Default)]
pub struct APIMetrics {
    pub total_requests: AtomicU64,
    pub total_errors: AtomicU64,
    pub cache_hits: AtomicU64,
    pub cache_misses: AtomicU64,
    pub total_latency_ms: AtomicU64,
}

/// Client for a vLLM (or other OpenAI-compatible) inference server, with a
/// response cache and shared request metrics.
pub struct VLLMClient {
    http_client: reqwest::Client,
    config: LLMServiceConfig,
    embedding_dimensions: usize,
    cache: Arc<super::cache::LLMCache>,
    metrics: Arc<APIMetrics>,
}

impl VLLMClient {
    pub fn new(config: LLMServiceConfig) -> Result<Self> {
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(config.timeout_secs))
            .build()
            .map_err(AgentRootError::Http)?;

        // Fall back to 384 dimensions (a common size for small
        // sentence-embedding models) when the config does not specify one.
        let embedding_dimensions = config.embedding_dimensions.unwrap_or(384);

        let cache = Arc::new(super::cache::LLMCache::new());

        let metrics = Arc::new(APIMetrics::default());

        Ok(Self {
            http_client,
            config,
            embedding_dimensions,
            cache,
            metrics,
        })
    }

    /// Convenience constructor using the default [`LLMServiceConfig`].
    pub fn from_env() -> Result<Self> {
        let config = LLMServiceConfig::default();
        Self::new(config)
    }

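    /// Take a point-in-time [`MetricsSnapshot`] of the running counters.
    ///
    /// A minimal usage sketch:
    ///
    /// ```ignore
    /// let snapshot = client.metrics();
    /// tracing::info!("cache hit rate: {:.1}%", snapshot.cache_hit_rate);
    /// ```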
    pub fn metrics(&self) -> MetricsSnapshot {
        use std::sync::atomic::Ordering;

        let total = self.metrics.total_requests.load(Ordering::Relaxed);
        let hits = self.metrics.cache_hits.load(Ordering::Relaxed);
        let misses = self.metrics.cache_misses.load(Ordering::Relaxed);
        // Hits and misses are counted per cache lookup (one lookup per text
        // in a batch), so the hit rate is taken over lookups rather than
        // requests; dividing by requests could exceed 100%.
        let lookups = hits + misses;

        MetricsSnapshot {
            total_requests: total,
            total_errors: self.metrics.total_errors.load(Ordering::Relaxed),
            cache_hits: hits,
            cache_misses: misses,
            cache_hit_rate: if lookups > 0 {
                hits as f64 / lookups as f64 * 100.0
            } else {
                0.0
            },
            avg_latency_ms: if total > 0 {
                self.metrics.total_latency_ms.load(Ordering::Relaxed) as f64 / total as f64
            } else {
                0.0
            },
        }
    }

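    /// Embed `texts` sequentially in fixed-size batches, invoking
    /// `progress_callback(processed, total)` after each batch. When no
    /// callback is wanted, pass `None::<fn(usize, usize)>` so the type
    /// parameter can still be inferred.
    ///
    /// A minimal usage sketch (batch size and callback are illustrative):
    ///
    /// ```ignore
    /// let embeddings = client
    ///     .embed_batch_optimized(&texts, 32, Some(|done, total| {
    ///         tracing::info!("embedded {}/{} texts", done, total);
    ///     }))
    ///     .await?;
    /// assert_eq!(embeddings.len(), texts.len());
    /// ```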
    pub async fn embed_batch_optimized<F>(
        &self,
        texts: &[String],
        batch_size: usize,
        progress_callback: Option<F>,
    ) -> Result<Vec<Vec<f32>>>
    where
        F: Fn(usize, usize) + Send + Sync,
    {
        const DEFAULT_BATCH_SIZE: usize = 32;
        let chunk_size = if batch_size > 0 {
            batch_size
        } else {
            DEFAULT_BATCH_SIZE
        };

        let total = texts.len();
        let mut all_results = Vec::with_capacity(total);

        for (chunk_idx, chunk) in texts.chunks(chunk_size).enumerate() {
            let chunk_results = self.embed_batch(chunk).await?;
            all_results.extend(chunk_results);

            if let Some(ref callback) = progress_callback {
                // Report how many texts have been processed so far, clamped
                // to `total` since the final chunk may be short.
                callback(((chunk_idx + 1) * chunk_size).min(total), total);
            }
        }

        Ok(all_results)
    }

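    /// Embed `texts` by splitting them into batches and keeping up to
    /// `max_concurrent` embedding requests in flight at once. Completed
    /// batches are re-sorted by index, so output order matches input order.
    ///
    /// A minimal usage sketch (batch size and concurrency are illustrative):
    ///
    /// ```ignore
    /// // 32 texts per request, at most 4 requests in flight.
    /// let embeddings = client.embed_batch_parallel(&texts, 32, 4).await?;
    /// assert_eq!(embeddings.len(), texts.len());
    /// ```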
    pub async fn embed_batch_parallel(
        &self,
        texts: &[String],
        batch_size: usize,
        max_concurrent: usize,
    ) -> Result<Vec<Vec<f32>>> {
        use futures::stream::StreamExt;

        const DEFAULT_BATCH_SIZE: usize = 32;
        const DEFAULT_CONCURRENT: usize = 4;

        let chunk_size = if batch_size > 0 {
            batch_size
        } else {
            DEFAULT_BATCH_SIZE
        };
        let concurrent = if max_concurrent > 0 {
            max_concurrent
        } else {
            DEFAULT_CONCURRENT
        };

        let chunks: Vec<_> = texts.chunks(chunk_size).collect();
        let total_chunks = chunks.len();

        tracing::info!(
            "Embedding {} texts in {} batches ({} concurrent)",
            texts.len(),
            total_chunks,
            concurrent
        );

        let results: Vec<_> = stream::iter(chunks)
            .enumerate()
            .map(|(idx, chunk)| async move {
                tracing::debug!("Processing batch {}/{}", idx + 1, total_chunks);
                let result = self.embed_batch(chunk).await;
                (idx, result)
            })
            .buffer_unordered(concurrent)
            .collect()
            .await;

        // `buffer_unordered` yields batches in completion order, so restore
        // the original order before flattening.
        let mut sorted_results: Vec<_> = results;
        sorted_results.sort_by_key(|(idx, _)| *idx);

        let mut all_embeddings = Vec::with_capacity(texts.len());
        for (_, result) in sorted_results {
            all_embeddings.extend(result?);
        }

        Ok(all_embeddings)
    }
}

/// A point-in-time, serializable view of [`APIMetrics`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsSnapshot {
    pub total_requests: u64,
    pub total_errors: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    pub cache_hit_rate: f64,
    pub avg_latency_ms: f64,
}

#[async_trait]
impl LLMClient for VLLMClient {
    async fn chat_completion(&self, messages: Vec<ChatMessage>) -> Result<String> {
        use std::sync::atomic::Ordering;

        let start = Instant::now();
        self.metrics.total_requests.fetch_add(1, Ordering::Relaxed);

        // Key the cache on the model plus the full serialized conversation.
        let messages_json = serde_json::to_string(&messages).unwrap_or_default();
        let cache_key = super::cache::chat_cache_key(&self.config.model, &messages_json);

        if let Some(cached) = self.cache.get(&cache_key) {
            tracing::debug!("Cache hit for chat completion");
            self.metrics.cache_hits.fetch_add(1, Ordering::Relaxed);
            return Ok(cached);
        }

        self.metrics.cache_misses.fetch_add(1, Ordering::Relaxed);

        #[derive(Serialize)]
        struct ChatRequest {
            model: String,
            messages: Vec<ChatMessage>,
            temperature: f32,
            max_tokens: u32,
        }

        #[derive(Deserialize)]
        struct ChatResponse {
            choices: Vec<ChatChoice>,
        }

        #[derive(Deserialize)]
        struct ChatChoice {
            message: ChatMessage,
        }

        // Note: sampling parameters are currently hard-coded.
        let request = ChatRequest {
            model: self.config.model.clone(),
            messages,
            temperature: 0.7,
            max_tokens: 512,
        };

        let url = format!("{}/v1/chat/completions", self.config.url);

        let mut req = self.http_client.post(&url).json(&request);

        if let Some(ref api_key) = self.config.api_key {
            req = req.header("Authorization", format!("Bearer {}", api_key));
        }

        let response = req.send().await.map_err(|e| {
            self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
            AgentRootError::Http(e)
        })?;

        if !response.status().is_success() {
            self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(AgentRootError::ExternalError(format!(
                "LLM service error (HTTP {}): {}",
                status, body
            )));
        }

        let chat_response: ChatResponse = response.json().await.map_err(|e| {
            self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
            AgentRootError::Http(e)
        })?;

        let content = chat_response
            .choices
            .first()
            .ok_or_else(|| {
                self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
                AgentRootError::Llm("No response from LLM".to_string())
            })?
            .message
            .content
            .clone();

        let _ = self.cache.set(cache_key, content.clone());

        let elapsed = start.elapsed().as_millis() as u64;
        self.metrics
            .total_latency_ms
            .fetch_add(elapsed, Ordering::Relaxed);

        Ok(content)
    }

    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let results = self.embed_batch(&[text.to_string()]).await?;
        results
            .into_iter()
            .next()
            .ok_or_else(|| AgentRootError::Llm("No embedding returned".to_string()))
    }

    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        use std::sync::atomic::Ordering;

        let start = Instant::now();
        self.metrics.total_requests.fetch_add(1, Ordering::Relaxed);

        // First pass: satisfy what we can from the cache, remembering the
        // positions of texts that still need a network request.
        let mut results = Vec::with_capacity(texts.len());
        let mut uncached_texts = Vec::new();
        let mut uncached_indices = Vec::new();

        for (i, text) in texts.iter().enumerate() {
            let cache_key = super::cache::embedding_cache_key(&self.config.embedding_model, text);
            if let Some(cached) = self.cache.get(&cache_key) {
                if let Ok(embedding) = serde_json::from_str::<Vec<f32>>(&cached) {
                    results.push(Some(embedding));
                    self.metrics.cache_hits.fetch_add(1, Ordering::Relaxed);
                    continue;
                }
            }
            self.metrics.cache_misses.fetch_add(1, Ordering::Relaxed);
            results.push(None);
            uncached_texts.push(text.clone());
            uncached_indices.push(i);
        }

        if uncached_texts.is_empty() {
            tracing::debug!("All {} embeddings from cache", texts.len());
            return Ok(results.into_iter().map(|r| r.unwrap()).collect());
        }

        tracing::debug!(
            "Embedding batch: {} cached, {} to fetch",
            texts.len() - uncached_texts.len(),
            uncached_texts.len()
        );

        #[derive(Serialize)]
        struct EmbedRequest {
            model: String,
            input: Vec<String>,
        }

        #[derive(Deserialize)]
        struct EmbedResponse {
            data: Vec<EmbedData>,
        }

        #[derive(Deserialize)]
        struct EmbedData {
            embedding: Vec<f32>,
        }

        let request = EmbedRequest {
            model: self.config.embedding_model.clone(),
            input: uncached_texts.clone(),
        };

        let url = format!("{}/v1/embeddings", self.config.embeddings_url());

        let mut req = self.http_client.post(&url).json(&request);

        if let Some(ref api_key) = self.config.api_key {
            req = req.header("Authorization", format!("Bearer {}", api_key));
        }

        let response = req.send().await.map_err(|e| {
            self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
            AgentRootError::Http(e)
        })?;

        if !response.status().is_success() {
            self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(AgentRootError::ExternalError(format!(
                "Embedding service error (HTTP {}): {}",
                status, body
            )));
        }

        let embed_response: EmbedResponse = response.json().await.map_err(|e| {
            self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
            AgentRootError::Http(e)
        })?;

        // Guard against a short response: indexing `uncached_indices` and the
        // final `unwrap` below both rely on every input being answered.
        if embed_response.data.len() != uncached_texts.len() {
            self.metrics.total_errors.fetch_add(1, Ordering::Relaxed);
            return Err(AgentRootError::Llm(format!(
                "Embedding service returned {} embeddings for {} inputs",
                embed_response.data.len(),
                uncached_texts.len()
            )));
        }

        // Second pass: slot fetched embeddings back into their original
        // positions and cache them for next time.
        for (i, embedding) in embed_response.data.into_iter().enumerate() {
            let original_idx = uncached_indices[i];
            results[original_idx] = Some(embedding.embedding.clone());

            let cache_key =
                super::cache::embedding_cache_key(&self.config.embedding_model, &uncached_texts[i]);
            if let Ok(json) = serde_json::to_string(&embedding.embedding) {
                let _ = self.cache.set(cache_key, json);
            }
        }

        let elapsed = start.elapsed().as_millis() as u64;
        self.metrics
            .total_latency_ms
            .fetch_add(elapsed, Ordering::Relaxed);

        // Every slot is `Some` at this point: cached entries were filled in
        // the first pass and the length check above covers the rest.
        Ok(results.into_iter().map(|r| r.unwrap()).collect())
    }

    fn embedding_dimensions(&self) -> usize {
        self.embedding_dimensions
    }

    fn model_name(&self) -> &str {
        &self.config.model
    }
}

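/// Generate [`DocumentMetadata`] for `content` by prompting the LLM for a
/// JSON object and parsing the reply.
///
/// A minimal usage sketch (the document text and context values are
/// illustrative):
///
/// ```ignore
/// let metadata = generate_metadata_with_llm(&client, &document_text, &context).await?;
/// ```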
pub async fn generate_metadata_with_llm(
    client: &dyn LLMClient,
    content: &str,
    context: &MetadataContext,
) -> Result<DocumentMetadata> {
    let prompt = build_metadata_prompt(content, context);

    let messages = vec![
        ChatMessage::system(
            "You are a metadata generator. Extract structured metadata from documents. \
             Respond ONLY with valid JSON matching the schema.",
        ),
        ChatMessage::user(prompt),
    ];

    let response = client.chat_completion(messages).await?;

    parse_metadata_response(&response)
}

fn build_metadata_prompt(content: &str, context: &MetadataContext) -> String {
    // Truncate long documents, backing off to a char boundary so that
    // slicing can never split a multi-byte UTF-8 sequence and panic.
    const MAX_CONTENT_BYTES: usize = 8000;
    let truncated = if content.len() > MAX_CONTENT_BYTES {
        let mut end = MAX_CONTENT_BYTES;
        while !content.is_char_boundary(end) {
            end -= 1;
        }
        &content[..end]
    } else {
        content
    };

    format!(
        r#"Generate metadata for this document:

Source type: {}
Language: {}
Collection: {}

Content:
{}

Output JSON with these fields:
- summary: 100-200 word summary
- semantic_title: improved title
- keywords: 5-10 keywords (array)
- category: document type
- intent: purpose description
- concepts: related concepts (array)
- difficulty: beginner/intermediate/advanced
- suggested_queries: search queries (array)

JSON:"#,
        context.source_type,
        context.language.as_deref().unwrap_or("unknown"),
        context.collection_name,
        truncated
    )
}

fn parse_metadata_response(response: &str) -> Result<DocumentMetadata> {
    // Models often wrap the JSON in prose or code fences; extract the
    // outermost `{ ... }` span when a well-formed one is present.
    let json_str = match (response.find('{'), response.rfind('}')) {
        (Some(start), Some(end)) if end >= start => &response[start..=end],
        _ => response,
    };

    serde_json::from_str(json_str)
        .map_err(|e| AgentRootError::Llm(format!("Failed to parse metadata JSON: {}", e)))
}