1use serde::{Deserialize, Serialize};
33use std::path::PathBuf;
34use std::process::Command;
35
36use crate::error::{Error, Result};
37
/// Upper bound enforced by `validate_user_input` on free-form text (prompts,
/// system prompts, similarity queries and template variable values).
const MAX_INPUT_LENGTH: usize = 10_000;

/// Upper bound enforced on model, collection and template identifiers.
const MAX_IDENTIFIER_LENGTH: usize = 256;
43
44fn validate_model_name(model: &str) -> Result<&str> {
48 if model.is_empty() {
49 return Err(Error::Validation("Model name cannot be empty".to_string()));
50 }
51 if model.len() > MAX_IDENTIFIER_LENGTH {
52 return Err(Error::Validation(format!(
53 "Model name exceeds maximum length of {} characters",
54 MAX_IDENTIFIER_LENGTH
55 )));
56 }
57 if !model
60 .chars()
61 .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '/' || c == ':' || c == '.')
62 {
63 return Err(Error::Validation(format!(
64 "Model name contains invalid characters: '{}'. Allowed: alphanumeric, -, _, /, :, .",
65 model
66 )));
67 }
68 if model.contains("..") {
70 return Err(Error::Validation(
71 "Model name cannot contain '..' (path traversal)".to_string(),
72 ));
73 }
74 Ok(model)
75}
76
77fn validate_collection_name(collection: &str) -> Result<&str> {
80 if collection.is_empty() {
81 return Err(Error::Validation(
82 "Collection name cannot be empty".to_string(),
83 ));
84 }
85 if collection.len() > MAX_IDENTIFIER_LENGTH {
86 return Err(Error::Validation(format!(
87 "Collection name exceeds maximum length of {} characters",
88 MAX_IDENTIFIER_LENGTH
89 )));
90 }
91 if !collection
93 .chars()
94 .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
95 {
96 return Err(Error::Validation(format!(
97 "Collection name contains invalid characters: '{}'. Allowed: alphanumeric, -, _",
98 collection
99 )));
100 }
101 Ok(collection)
102}
103
104fn validate_template_name(template: &str) -> Result<&str> {
107 if template.is_empty() {
108 return Err(Error::Validation(
109 "Template name cannot be empty".to_string(),
110 ));
111 }
112 if template.len() > MAX_IDENTIFIER_LENGTH {
113 return Err(Error::Validation(format!(
114 "Template name exceeds maximum length of {} characters",
115 MAX_IDENTIFIER_LENGTH
116 )));
117 }
118 if !template
119 .chars()
120 .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
121 {
122 return Err(Error::Validation(format!(
123 "Template name contains invalid characters: '{}'. Allowed: alphanumeric, -, _",
124 template
125 )));
126 }
127 Ok(template)
128}
129
/// Renders `c` for inclusion in an error message.
///
/// The invisible control characters `\n`, `\r` and `\0` are spelled out as
/// two-character backslash escapes; every other character is returned as-is.
fn format_dangerous_char(c: char) -> String {
    let escaped = match c {
        '\n' => "\\n",
        '\r' => "\\r",
        '\0' => "\\0",
        visible => return String::from(visible),
    };
    String::from(escaped)
}
139
140fn validate_db_path(path: &str) -> Result<&str> {
143 if path.is_empty() {
144 return Err(Error::Validation(
145 "Database path cannot be empty".to_string(),
146 ));
147 }
148 if path.len() > 4096 {
149 return Err(Error::Validation(
150 "Database path exceeds maximum length".to_string(),
151 ));
152 }
153 let dangerous_chars = [
156 '$', '`', '!', '&', '|', ';', '(', ')', '{', '}', '<', '>', '\n', '\r', '\0', '"', '\'',
157 '\\',
158 ];
159 for c in dangerous_chars {
160 if path.contains(c) {
161 return Err(Error::Validation(format!(
162 "Database path contains dangerous character: '{}'",
163 format_dangerous_char(c)
164 )));
165 }
166 }
167 if path.contains("..") {
169 return Err(Error::Validation(
170 "Database path cannot contain '..' (path traversal)".to_string(),
171 ));
172 }
173 Ok(path)
174}
175
176fn validate_user_input(input: &str) -> Result<&str> {
179 if input.len() > MAX_INPUT_LENGTH {
180 return Err(Error::Validation(format!(
181 "Input exceeds maximum length of {} characters",
182 MAX_INPUT_LENGTH
183 )));
184 }
185 Ok(input)
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct LlmCliConfig {
191 pub binary_path: PathBuf,
193 pub default_model: Option<String>,
195 pub default_embedding_model: Option<String>,
197 pub database_path: Option<PathBuf>,
199}
200
201impl Default for LlmCliConfig {
202 fn default() -> Self {
203 Self {
204 binary_path: PathBuf::from("llm"),
205 default_model: None,
206 default_embedding_model: None,
207 database_path: None,
208 }
209 }
210}
211
212#[derive(Debug, Clone, Default, Serialize, Deserialize)]
214pub struct EmbeddingConfig {
215 pub model: Option<String>,
217 pub collection: Option<String>,
219 pub database: Option<PathBuf>,
221 pub store_metadata: bool,
223}
224
225#[derive(Debug, Clone, Default, Serialize, Deserialize)]
227pub struct ClusterConfig {
228 pub model: Option<String>,
230 pub format: Option<String>,
232 pub summary: bool,
234}
235
236#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct RagConfig {
239 pub collection: String,
241 pub top_k: usize,
243 pub model: Option<String>,
245 pub system_prompt: Option<String>,
247}
248
249impl Default for RagConfig {
250 fn default() -> Self {
251 Self {
252 collection: "default".to_string(),
253 top_k: 5,
254 model: None,
255 system_prompt: None,
256 }
257 }
258}
259
260#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct PromptResult {
263 pub response: String,
265 pub model: String,
267 pub tokens_used: Option<u64>,
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
273pub struct EmbeddingResult {
274 pub embedding: Vec<f32>,
276 pub text: String,
278 pub model: String,
280}
281
282#[derive(Debug, Clone, Serialize, Deserialize)]
284pub struct ClusterResult {
285 pub clusters: Vec<ClusterAssignment>,
287 pub summaries: Option<Vec<String>>,
289}
290
291#[derive(Debug, Clone, Serialize, Deserialize)]
293pub struct ClusterAssignment {
294 pub id: String,
296 pub cluster: usize,
298 pub score: Option<f64>,
300}
301
302#[derive(Debug, Clone, Serialize, Deserialize)]
304pub struct SimilarityResult {
305 pub matches: Vec<SimilarityMatch>,
307}
308
309#[derive(Debug, Clone, Serialize, Deserialize)]
311pub struct SimilarityMatch {
312 pub content: String,
314 pub score: f64,
316 pub id: Option<String>,
318 pub metadata: Option<serde_json::Value>,
320}
321
322#[derive(Debug, Clone)]
324pub struct LlmCliClient {
325 config: LlmCliConfig,
326}
327
328impl LlmCliClient {
329 pub fn new() -> Result<Self> {
331 Ok(Self {
332 config: LlmCliConfig::default(),
333 })
334 }
335
336 pub fn with_config(config: LlmCliConfig) -> Self {
338 Self { config }
339 }
340
341 pub fn is_available(&self) -> bool {
343 Command::new(&self.config.binary_path)
344 .arg("--version")
345 .output()
346 .map(|o| o.status.success())
347 .unwrap_or(false)
348 }
349
350 pub fn list_models(&self) -> Result<Vec<String>> {
352 let output = Command::new(&self.config.binary_path)
353 .arg("models")
354 .arg("list")
355 .output()
356 .map_err(Error::Io)?;
357
358 if !output.status.success() {
359 return Err(Error::Io(std::io::Error::other(
360 String::from_utf8_lossy(&output.stderr).to_string(),
361 )));
362 }
363
364 let stdout = String::from_utf8_lossy(&output.stdout);
365 Ok(stdout.lines().map(String::from).collect())
366 }
367
368 pub fn prompt(&self, text: &str, model: Option<&str>) -> Result<PromptResult> {
370 let validated_text = validate_user_input(text)?;
372 let validated_model = match model {
373 Some(m) => Some(validate_model_name(m)?),
374 None => None,
375 };
376
377 let mut cmd = Command::new(&self.config.binary_path);
378
379 if let Some(m) = validated_model.or(self.config.default_model.as_deref()) {
380 cmd.arg("-m").arg(m);
381 }
382
383 cmd.arg(validated_text);
384
385 let output = cmd.output().map_err(Error::Io)?;
386
387 if !output.status.success() {
388 return Err(Error::Io(std::io::Error::other(
389 String::from_utf8_lossy(&output.stderr).to_string(),
390 )));
391 }
392
393 Ok(PromptResult {
394 response: String::from_utf8_lossy(&output.stdout).to_string(),
395 model: validated_model
396 .or(self.config.default_model.as_deref())
397 .unwrap_or("default")
398 .to_string(),
399 tokens_used: None,
400 })
401 }
402
403 pub fn prompt_with_system(
405 &self,
406 text: &str,
407 system: &str,
408 model: Option<&str>,
409 ) -> Result<PromptResult> {
410 let validated_text = validate_user_input(text)?;
412 let validated_system = validate_user_input(system)?;
413 let validated_model = match model {
414 Some(m) => Some(validate_model_name(m)?),
415 None => None,
416 };
417
418 let mut cmd = Command::new(&self.config.binary_path);
419
420 if let Some(m) = validated_model.or(self.config.default_model.as_deref()) {
421 cmd.arg("-m").arg(m);
422 }
423
424 cmd.arg("-s").arg(validated_system);
425 cmd.arg(validated_text);
426
427 let output = cmd.output().map_err(Error::Io)?;
428
429 if !output.status.success() {
430 return Err(Error::Io(std::io::Error::other(
431 String::from_utf8_lossy(&output.stderr).to_string(),
432 )));
433 }
434
435 Ok(PromptResult {
436 response: String::from_utf8_lossy(&output.stdout).to_string(),
437 model: validated_model
438 .or(self.config.default_model.as_deref())
439 .unwrap_or("default")
440 .to_string(),
441 tokens_used: None,
442 })
443 }
444
445 pub fn embed(&self, text: &str, config: Option<&EmbeddingConfig>) -> Result<EmbeddingResult> {
447 let validated_text = validate_user_input(text)?;
449
450 let mut cmd = Command::new(&self.config.binary_path);
451 cmd.arg("embed");
452
453 if let Some(cfg) = config {
455 if let Some(ref m) = cfg.model {
456 let validated = validate_model_name(m)?;
457 cmd.arg("-m").arg(validated);
458 }
459 if let Some(ref c) = cfg.collection {
460 let validated = validate_collection_name(c)?;
461 cmd.arg("-c").arg(validated);
462 }
463 if let Some(ref db) = cfg.database {
464 let db_str = db.to_string_lossy();
465 let validated = validate_db_path(&db_str)?;
466 cmd.arg("-d").arg(validated);
467 }
468 }
469
470 cmd.arg(validated_text);
471
472 let output = cmd.output().map_err(Error::Io)?;
473
474 if !output.status.success() {
475 return Err(Error::Io(std::io::Error::other(
476 String::from_utf8_lossy(&output.stderr).to_string(),
477 )));
478 }
479
480 let stdout = String::from_utf8_lossy(&output.stdout);
482 let embedding: Vec<f32> = serde_json::from_str(&stdout).map_err(|e| {
483 Error::Io(std::io::Error::new(
484 std::io::ErrorKind::InvalidData,
485 format!(
486 "Failed to parse embedding response: {}. Response: {}",
487 e, stdout
488 ),
489 ))
490 })?;
491
492 Ok(EmbeddingResult {
493 embedding,
494 text: text.to_string(),
495 model: config
496 .and_then(|c| c.model.clone())
497 .or_else(|| self.config.default_embedding_model.clone())
498 .unwrap_or_else(|| "default".to_string()),
499 })
500 }
501
502 pub fn cluster(
504 &self,
505 input: &str,
506 num_clusters: usize,
507 config: Option<&ClusterConfig>,
508 ) -> Result<ClusterResult> {
509 let validated_input = validate_user_input(input)?;
511
512 let mut cmd = Command::new(&self.config.binary_path);
513 cmd.arg("cluster");
514 cmd.arg("-n").arg(num_clusters.to_string());
515
516 if let Some(cfg) = config {
517 if let Some(ref m) = cfg.model {
518 let validated = validate_model_name(m)?;
519 cmd.arg("-m").arg(validated);
520 }
521 if cfg.summary {
522 cmd.arg("--summary");
523 }
524 }
525
526 cmd.arg(validated_input);
527
528 let output = cmd.output().map_err(Error::Io)?;
529
530 if !output.status.success() {
531 return Err(Error::Io(std::io::Error::other(
532 String::from_utf8_lossy(&output.stderr).to_string(),
533 )));
534 }
535
536 Ok(ClusterResult {
538 clusters: vec![],
539 summaries: None,
540 })
541 }
542
543 pub fn similar(&self, query: &str, collection: &str, top_k: usize) -> Result<SimilarityResult> {
545 let validated_query = validate_user_input(query)?;
547 let validated_collection = validate_collection_name(collection)?;
548
549 let mut cmd = Command::new(&self.config.binary_path);
550 cmd.arg("similar");
551 cmd.arg("-c").arg(validated_collection);
552 cmd.arg("-n").arg(top_k.to_string());
553 cmd.arg(validated_query);
554
555 let output = cmd.output().map_err(Error::Io)?;
556
557 if !output.status.success() {
558 return Err(Error::Io(std::io::Error::other(
559 String::from_utf8_lossy(&output.stderr).to_string(),
560 )));
561 }
562
563 Ok(SimilarityResult { matches: vec![] })
565 }
566
567 pub fn rag(&self, query: &str, config: &RagConfig) -> Result<PromptResult> {
569 let similar = self.similar(query, &config.collection, config.top_k)?;
571
572 let context: String = similar
574 .matches
575 .iter()
576 .map(|m| m.content.clone())
577 .collect::<Vec<_>>()
578 .join("\n\n---\n\n");
579
580 let rag_prompt = format!(
582 "Based on the following context, answer the question.\n\n\
583 Context:\n{}\n\n\
584 Question: {}",
585 context, query
586 );
587
588 if let Some(ref system) = config.system_prompt {
590 self.prompt_with_system(&rag_prompt, system, config.model.as_deref())
591 } else {
592 self.prompt(&rag_prompt, config.model.as_deref())
593 }
594 }
595
596 pub fn prompt_with_template(
598 &self,
599 template: &str,
600 variables: &[(&str, &str)],
601 model: Option<&str>,
602 ) -> Result<PromptResult> {
603 let validated_template = validate_template_name(template)?;
605 let validated_model = match model {
606 Some(m) => Some(validate_model_name(m)?),
607 None => None,
608 };
609
610 let mut cmd = Command::new(&self.config.binary_path);
611 cmd.arg("-t").arg(validated_template);
612
613 if let Some(m) = validated_model.or(self.config.default_model.as_deref()) {
614 cmd.arg("-m").arg(m);
615 }
616
617 for (key, value) in variables {
619 if !key
621 .chars()
622 .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
623 {
624 return Err(Error::Validation(format!(
625 "Template variable name contains invalid characters: '{}'",
626 key
627 )));
628 }
629 let validated_value = validate_user_input(value)?;
631 cmd.arg("-p").arg(*key).arg(validated_value);
632 }
633
634 let output = cmd.output().map_err(Error::Io)?;
635
636 if !output.status.success() {
637 return Err(Error::Io(std::io::Error::other(
638 String::from_utf8_lossy(&output.stderr).to_string(),
639 )));
640 }
641
642 Ok(PromptResult {
643 response: String::from_utf8_lossy(&output.stdout).to_string(),
644 model: validated_model
645 .or(self.config.default_model.as_deref())
646 .unwrap_or("default")
647 .to_string(),
648 tokens_used: None,
649 })
650 }
651}
652
653impl Default for LlmCliClient {
654 fn default() -> Self {
655 Self::new().expect("Failed to create default LlmCliClient")
656 }
657}
658
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_model_name_valid() {
        let accepted = [
            "gpt-4o-mini",
            "claude-sonnet-4",
            "sentence-transformers/all-MiniLM-L6-v2",
            "model:latest",
        ];
        for name in accepted {
            assert!(validate_model_name(name).is_ok(), "{name:?} should be accepted");
        }
    }

    #[test]
    fn test_validate_model_name_invalid() {
        let rejected = ["", "model; rm -rf /", "model$(whoami)", "../../../etc/passwd"];
        for name in rejected {
            assert!(validate_model_name(name).is_err(), "{name:?} should be rejected");
        }
    }

    #[test]
    fn test_validate_collection_name_valid() {
        for name in ["my-collection", "collection_123"] {
            assert!(validate_collection_name(name).is_ok(), "{name:?} should be accepted");
        }
    }

    #[test]
    fn test_validate_collection_name_invalid() {
        for name in ["", "collection/path", "col; drop table"] {
            assert!(validate_collection_name(name).is_err(), "{name:?} should be rejected");
        }
    }

    #[test]
    fn test_validate_db_path_valid() {
        for path in ["/home/user/.llm/db.sqlite", "./data/embeddings.db"] {
            assert!(validate_db_path(path).is_ok(), "{path:?} should be accepted");
        }
    }

    #[test]
    fn test_validate_db_path_invalid() {
        let rejected = [
            "",
            "/path/../../../etc/passwd",
            "/path$(whoami)/db",
            "/path`id`/db",
            "/path;rm -rf/db",
        ];
        for path in rejected {
            assert!(validate_db_path(path).is_err(), "{path:?} should be rejected");
        }
    }

    #[test]
    fn test_format_dangerous_char() {
        let cases = [('\n', "\\n"), ('\r', "\\r"), ('\0', "\\0"), ('$', "$")];
        for (input, expected) in cases {
            assert_eq!(format_dangerous_char(input), expected);
        }
    }

    #[test]
    fn test_client_creation() {
        assert!(LlmCliClient::new().is_ok());
    }

    #[test]
    fn test_config_default() {
        let LlmCliConfig {
            binary_path,
            default_model,
            ..
        } = LlmCliConfig::default();
        assert_eq!(binary_path, PathBuf::from("llm"));
        assert!(default_model.is_none());
    }
}