reasonkit/thinktool/
llm_cli.rs

1//! # LLM CLI Ecosystem Integration
2//!
3//! Integrates Simon Willison's LLM CLI tool for multi-model orchestration,
4//! embeddings, clustering, and RAG pipelines.
5//!
6//! ## Usage
7//!
8//! ```rust,ignore
9//! use reasonkit::thinktool::llm_cli::{LlmCliClient, EmbeddingConfig, ClusterConfig};
10//!
11//! let client = LlmCliClient::new()?;
12//!
13//! // Execute a prompt
//! let response = client.prompt("Analyze this code", Some("claude-sonnet-4"))?;
//!
//! // Generate embeddings
//! let embeddings = client.embed("text to embed", None)?;
//!
//! // Cluster documents
//! let clusters = client.cluster("documents", 5, None)?;
21//! ```
22//!
23//! ## Security
24//!
//! All user-provided inputs are validated before being passed as arguments to
//! the `llm` subprocess, which is spawned directly via `std::process::Command`
//! (no shell is involved). This includes:
27//! - Model names (alphanumeric, hyphens, underscores, slashes, colons, dots)
28//! - Collection names (alphanumeric, hyphens, underscores)
29//! - Database paths (validated as safe filesystem paths)
30//! - Template names (alphanumeric, hyphens, underscores)
31
32use serde::{Deserialize, Serialize};
33use std::path::PathBuf;
34use std::process::Command;
35
36use crate::error::{Error, Result};
37
/// Upper bound for free-form text such as prompts and system messages.
///
/// Compared against `str::len()`; oversized inputs are rejected to limit DoS exposure.
const MAX_INPUT_LENGTH: usize = 10_000;

/// Upper bound for identifiers such as model, collection, and template names.
///
/// Compared against `str::len()`.
const MAX_IDENTIFIER_LENGTH: usize = 256;
43
44/// Validate and sanitize a model name.
45/// Allowed characters: alphanumeric, hyphens, underscores, slashes, colons, dots.
46/// Examples: "gpt-4o-mini", "claude-sonnet-4", "sentence-transformers/all-MiniLM-L6-v2"
47fn validate_model_name(model: &str) -> Result<&str> {
48    if model.is_empty() {
49        return Err(Error::Validation("Model name cannot be empty".to_string()));
50    }
51    if model.len() > MAX_IDENTIFIER_LENGTH {
52        return Err(Error::Validation(format!(
53            "Model name exceeds maximum length of {} characters",
54            MAX_IDENTIFIER_LENGTH
55        )));
56    }
57    // Allow alphanumeric, hyphens, underscores, slashes, colons, dots
58    // This covers: gpt-4o-mini, claude-3, sentence-transformers/all-MiniLM-L6-v2
59    if !model
60        .chars()
61        .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '/' || c == ':' || c == '.')
62    {
63        return Err(Error::Validation(format!(
64            "Model name contains invalid characters: '{}'. Allowed: alphanumeric, -, _, /, :, .",
65            model
66        )));
67    }
68    // Prevent path traversal attempts
69    if model.contains("..") {
70        return Err(Error::Validation(
71            "Model name cannot contain '..' (path traversal)".to_string(),
72        ));
73    }
74    Ok(model)
75}
76
77/// Validate a collection name.
78/// Allowed characters: alphanumeric, hyphens, underscores.
79fn validate_collection_name(collection: &str) -> Result<&str> {
80    if collection.is_empty() {
81        return Err(Error::Validation(
82            "Collection name cannot be empty".to_string(),
83        ));
84    }
85    if collection.len() > MAX_IDENTIFIER_LENGTH {
86        return Err(Error::Validation(format!(
87            "Collection name exceeds maximum length of {} characters",
88            MAX_IDENTIFIER_LENGTH
89        )));
90    }
91    // Collection names should be simple identifiers
92    if !collection
93        .chars()
94        .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
95    {
96        return Err(Error::Validation(format!(
97            "Collection name contains invalid characters: '{}'. Allowed: alphanumeric, -, _",
98            collection
99        )));
100    }
101    Ok(collection)
102}
103
104/// Validate a template name.
105/// Allowed characters: alphanumeric, hyphens, underscores.
106fn validate_template_name(template: &str) -> Result<&str> {
107    if template.is_empty() {
108        return Err(Error::Validation(
109            "Template name cannot be empty".to_string(),
110        ));
111    }
112    if template.len() > MAX_IDENTIFIER_LENGTH {
113        return Err(Error::Validation(format!(
114            "Template name exceeds maximum length of {} characters",
115            MAX_IDENTIFIER_LENGTH
116        )));
117    }
118    if !template
119        .chars()
120        .all(|c| c.is_alphanumeric() || c == '-' || c == '_')
121    {
122        return Err(Error::Validation(format!(
123            "Template name contains invalid characters: '{}'. Allowed: alphanumeric, -, _",
124            template
125        )));
126    }
127    Ok(template)
128}
129
/// Render `c` for inclusion in an error message.
///
/// Newline, carriage return and NUL are shown as the two-character escapes
/// `\n`, `\r` and `\0`; every other character is returned as-is.
fn format_dangerous_char(c: char) -> String {
    let escaped = match c {
        '\n' => "\\n",
        '\r' => "\\r",
        '\0' => "\\0",
        // Printable as-is: no escape needed.
        other => return other.to_string(),
    };
    escaped.to_string()
}
139
140/// Validate a database path.
141/// Must be a valid path without shell metacharacters.
142fn validate_db_path(path: &str) -> Result<&str> {
143    if path.is_empty() {
144        return Err(Error::Validation(
145            "Database path cannot be empty".to_string(),
146        ));
147    }
148    if path.len() > 4096 {
149        return Err(Error::Validation(
150            "Database path exceeds maximum length".to_string(),
151        ));
152    }
153    // Reject shell metacharacters and control characters
154    // Safe characters: alphanumeric, path separators, dots, hyphens, underscores
155    let dangerous_chars = [
156        '$', '`', '!', '&', '|', ';', '(', ')', '{', '}', '<', '>', '\n', '\r', '\0', '"', '\'',
157        '\\',
158    ];
159    for c in dangerous_chars {
160        if path.contains(c) {
161            return Err(Error::Validation(format!(
162                "Database path contains dangerous character: '{}'",
163                format_dangerous_char(c)
164            )));
165        }
166    }
167    // Prevent path traversal
168    if path.contains("..") {
169        return Err(Error::Validation(
170            "Database path cannot contain '..' (path traversal)".to_string(),
171        ));
172    }
173    Ok(path)
174}
175
176/// Validate user input text (prompts, system messages).
177/// Limits length to prevent DoS attacks.
178fn validate_user_input(input: &str) -> Result<&str> {
179    if input.len() > MAX_INPUT_LENGTH {
180        return Err(Error::Validation(format!(
181            "Input exceeds maximum length of {} characters",
182            MAX_INPUT_LENGTH
183        )));
184    }
185    Ok(input)
186}
187
188/// Configuration for the LLM CLI client
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct LlmCliConfig {
191    /// Path to the LLM CLI binary (default: "llm")
192    pub binary_path: PathBuf,
193    /// Default model to use
194    pub default_model: Option<String>,
195    /// Default embedding model
196    pub default_embedding_model: Option<String>,
197    /// Database path for embeddings
198    pub database_path: Option<PathBuf>,
199}
200
201impl Default for LlmCliConfig {
202    fn default() -> Self {
203        Self {
204            binary_path: PathBuf::from("llm"),
205            default_model: None,
206            default_embedding_model: None,
207            database_path: None,
208        }
209    }
210}
211
212/// Configuration for embedding operations
213#[derive(Debug, Clone, Default, Serialize, Deserialize)]
214pub struct EmbeddingConfig {
215    /// Embedding model to use
216    pub model: Option<String>,
217    /// Collection name for storing embeddings
218    pub collection: Option<String>,
219    /// Database path
220    pub database: Option<PathBuf>,
221    /// Store metadata with embeddings
222    pub store_metadata: bool,
223}
224
225/// Configuration for clustering operations
226#[derive(Debug, Clone, Default, Serialize, Deserialize)]
227pub struct ClusterConfig {
228    /// Embedding model for clustering
229    pub model: Option<String>,
230    /// Output format (json, csv, etc.)
231    pub format: Option<String>,
232    /// Include summary in output
233    pub summary: bool,
234}
235
236/// Configuration for RAG operations
237#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct RagConfig {
239    /// Collection to search
240    pub collection: String,
241    /// Number of results to retrieve
242    pub top_k: usize,
243    /// Model for generation
244    pub model: Option<String>,
245    /// System prompt for RAG
246    pub system_prompt: Option<String>,
247}
248
249impl Default for RagConfig {
250    fn default() -> Self {
251        Self {
252            collection: "default".to_string(),
253            top_k: 5,
254            model: None,
255            system_prompt: None,
256        }
257    }
258}
259
260/// Result from an LLM prompt
261#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct PromptResult {
263    /// The generated response
264    pub response: String,
265    /// Model used
266    pub model: String,
267    /// Tokens used (if available)
268    pub tokens_used: Option<u64>,
269}
270
271/// Result from embedding operation
272#[derive(Debug, Clone, Serialize, Deserialize)]
273pub struct EmbeddingResult {
274    /// The embedding vector
275    pub embedding: Vec<f32>,
276    /// Text that was embedded
277    pub text: String,
278    /// Model used
279    pub model: String,
280}
281
282/// Result from clustering operation
283#[derive(Debug, Clone, Serialize, Deserialize)]
284pub struct ClusterResult {
285    /// Cluster assignments
286    pub clusters: Vec<ClusterAssignment>,
287    /// Cluster summaries (if requested)
288    pub summaries: Option<Vec<String>>,
289}
290
291/// Individual cluster assignment
292#[derive(Debug, Clone, Serialize, Deserialize)]
293pub struct ClusterAssignment {
294    /// Document ID or index
295    pub id: String,
296    /// Assigned cluster
297    pub cluster: usize,
298    /// Similarity score within cluster
299    pub score: Option<f64>,
300}
301
302/// Result from similarity search
303#[derive(Debug, Clone, Serialize, Deserialize)]
304pub struct SimilarityResult {
305    /// Matching documents
306    pub matches: Vec<SimilarityMatch>,
307}
308
309/// Individual similarity match
310#[derive(Debug, Clone, Serialize, Deserialize)]
311pub struct SimilarityMatch {
312    /// Document content
313    pub content: String,
314    /// Similarity score
315    pub score: f64,
316    /// Document ID
317    pub id: Option<String>,
318    /// Metadata
319    pub metadata: Option<serde_json::Value>,
320}
321
322/// LLM CLI Client for multi-model orchestration
323#[derive(Debug, Clone)]
324pub struct LlmCliClient {
325    config: LlmCliConfig,
326}
327
328impl LlmCliClient {
329    /// Create a new LLM CLI client with default configuration
330    pub fn new() -> Result<Self> {
331        Ok(Self {
332            config: LlmCliConfig::default(),
333        })
334    }
335
336    /// Create a new LLM CLI client with custom configuration
337    pub fn with_config(config: LlmCliConfig) -> Self {
338        Self { config }
339    }
340
341    /// Check if the LLM CLI is available
342    pub fn is_available(&self) -> bool {
343        Command::new(&self.config.binary_path)
344            .arg("--version")
345            .output()
346            .map(|o| o.status.success())
347            .unwrap_or(false)
348    }
349
350    /// List available models
351    pub fn list_models(&self) -> Result<Vec<String>> {
352        let output = Command::new(&self.config.binary_path)
353            .arg("models")
354            .arg("list")
355            .output()
356            .map_err(Error::Io)?;
357
358        if !output.status.success() {
359            return Err(Error::Io(std::io::Error::other(
360                String::from_utf8_lossy(&output.stderr).to_string(),
361            )));
362        }
363
364        let stdout = String::from_utf8_lossy(&output.stdout);
365        Ok(stdout.lines().map(String::from).collect())
366    }
367
368    /// Execute a prompt with the LLM CLI
369    pub fn prompt(&self, text: &str, model: Option<&str>) -> Result<PromptResult> {
370        // Validate inputs
371        let validated_text = validate_user_input(text)?;
372        let validated_model = match model {
373            Some(m) => Some(validate_model_name(m)?),
374            None => None,
375        };
376
377        let mut cmd = Command::new(&self.config.binary_path);
378
379        if let Some(m) = validated_model.or(self.config.default_model.as_deref()) {
380            cmd.arg("-m").arg(m);
381        }
382
383        cmd.arg(validated_text);
384
385        let output = cmd.output().map_err(Error::Io)?;
386
387        if !output.status.success() {
388            return Err(Error::Io(std::io::Error::other(
389                String::from_utf8_lossy(&output.stderr).to_string(),
390            )));
391        }
392
393        Ok(PromptResult {
394            response: String::from_utf8_lossy(&output.stdout).to_string(),
395            model: validated_model
396                .or(self.config.default_model.as_deref())
397                .unwrap_or("default")
398                .to_string(),
399            tokens_used: None,
400        })
401    }
402
403    /// Execute a prompt with system message
404    pub fn prompt_with_system(
405        &self,
406        text: &str,
407        system: &str,
408        model: Option<&str>,
409    ) -> Result<PromptResult> {
410        // Validate inputs
411        let validated_text = validate_user_input(text)?;
412        let validated_system = validate_user_input(system)?;
413        let validated_model = match model {
414            Some(m) => Some(validate_model_name(m)?),
415            None => None,
416        };
417
418        let mut cmd = Command::new(&self.config.binary_path);
419
420        if let Some(m) = validated_model.or(self.config.default_model.as_deref()) {
421            cmd.arg("-m").arg(m);
422        }
423
424        cmd.arg("-s").arg(validated_system);
425        cmd.arg(validated_text);
426
427        let output = cmd.output().map_err(Error::Io)?;
428
429        if !output.status.success() {
430            return Err(Error::Io(std::io::Error::other(
431                String::from_utf8_lossy(&output.stderr).to_string(),
432            )));
433        }
434
435        Ok(PromptResult {
436            response: String::from_utf8_lossy(&output.stdout).to_string(),
437            model: validated_model
438                .or(self.config.default_model.as_deref())
439                .unwrap_or("default")
440                .to_string(),
441            tokens_used: None,
442        })
443    }
444
445    /// Generate embeddings for text
446    pub fn embed(&self, text: &str, config: Option<&EmbeddingConfig>) -> Result<EmbeddingResult> {
447        // Validate inputs
448        let validated_text = validate_user_input(text)?;
449
450        let mut cmd = Command::new(&self.config.binary_path);
451        cmd.arg("embed");
452
453        // Validate and add optional config
454        if let Some(cfg) = config {
455            if let Some(ref m) = cfg.model {
456                let validated = validate_model_name(m)?;
457                cmd.arg("-m").arg(validated);
458            }
459            if let Some(ref c) = cfg.collection {
460                let validated = validate_collection_name(c)?;
461                cmd.arg("-c").arg(validated);
462            }
463            if let Some(ref db) = cfg.database {
464                let db_str = db.to_string_lossy();
465                let validated = validate_db_path(&db_str)?;
466                cmd.arg("-d").arg(validated);
467            }
468        }
469
470        cmd.arg(validated_text);
471
472        let output = cmd.output().map_err(Error::Io)?;
473
474        if !output.status.success() {
475            return Err(Error::Io(std::io::Error::other(
476                String::from_utf8_lossy(&output.stderr).to_string(),
477            )));
478        }
479
480        // Parse JSON output
481        let stdout = String::from_utf8_lossy(&output.stdout);
482        let embedding: Vec<f32> = serde_json::from_str(&stdout).map_err(|e| {
483            Error::Io(std::io::Error::new(
484                std::io::ErrorKind::InvalidData,
485                format!(
486                    "Failed to parse embedding response: {}. Response: {}",
487                    e, stdout
488                ),
489            ))
490        })?;
491
492        Ok(EmbeddingResult {
493            embedding,
494            text: text.to_string(),
495            model: config
496                .and_then(|c| c.model.clone())
497                .or_else(|| self.config.default_embedding_model.clone())
498                .unwrap_or_else(|| "default".to_string()),
499        })
500    }
501
502    /// Cluster documents
503    pub fn cluster(
504        &self,
505        input: &str,
506        num_clusters: usize,
507        config: Option<&ClusterConfig>,
508    ) -> Result<ClusterResult> {
509        // Validate inputs
510        let validated_input = validate_user_input(input)?;
511
512        let mut cmd = Command::new(&self.config.binary_path);
513        cmd.arg("cluster");
514        cmd.arg("-n").arg(num_clusters.to_string());
515
516        if let Some(cfg) = config {
517            if let Some(ref m) = cfg.model {
518                let validated = validate_model_name(m)?;
519                cmd.arg("-m").arg(validated);
520            }
521            if cfg.summary {
522                cmd.arg("--summary");
523            }
524        }
525
526        cmd.arg(validated_input);
527
528        let output = cmd.output().map_err(Error::Io)?;
529
530        if !output.status.success() {
531            return Err(Error::Io(std::io::Error::other(
532                String::from_utf8_lossy(&output.stderr).to_string(),
533            )));
534        }
535
536        // Parse output - simplified for now
537        Ok(ClusterResult {
538            clusters: vec![],
539            summaries: None,
540        })
541    }
542
543    /// Perform similarity search
544    pub fn similar(&self, query: &str, collection: &str, top_k: usize) -> Result<SimilarityResult> {
545        // Validate inputs
546        let validated_query = validate_user_input(query)?;
547        let validated_collection = validate_collection_name(collection)?;
548
549        let mut cmd = Command::new(&self.config.binary_path);
550        cmd.arg("similar");
551        cmd.arg("-c").arg(validated_collection);
552        cmd.arg("-n").arg(top_k.to_string());
553        cmd.arg(validated_query);
554
555        let output = cmd.output().map_err(Error::Io)?;
556
557        if !output.status.success() {
558            return Err(Error::Io(std::io::Error::other(
559                String::from_utf8_lossy(&output.stderr).to_string(),
560            )));
561        }
562
563        // Parse output - simplified for now
564        Ok(SimilarityResult { matches: vec![] })
565    }
566
567    /// Execute a RAG pipeline
568    pub fn rag(&self, query: &str, config: &RagConfig) -> Result<PromptResult> {
569        // First, get similar documents
570        let similar = self.similar(query, &config.collection, config.top_k)?;
571
572        // Build context from matches
573        let context: String = similar
574            .matches
575            .iter()
576            .map(|m| m.content.clone())
577            .collect::<Vec<_>>()
578            .join("\n\n---\n\n");
579
580        // Build RAG prompt
581        let rag_prompt = format!(
582            "Based on the following context, answer the question.\n\n\
583             Context:\n{}\n\n\
584             Question: {}",
585            context, query
586        );
587
588        // Execute with optional system prompt
589        if let Some(ref system) = config.system_prompt {
590            self.prompt_with_system(&rag_prompt, system, config.model.as_deref())
591        } else {
592            self.prompt(&rag_prompt, config.model.as_deref())
593        }
594    }
595
596    /// Execute a prompt using a template
597    pub fn prompt_with_template(
598        &self,
599        template: &str,
600        variables: &[(&str, &str)],
601        model: Option<&str>,
602    ) -> Result<PromptResult> {
603        // Validate template name
604        let validated_template = validate_template_name(template)?;
605        let validated_model = match model {
606            Some(m) => Some(validate_model_name(m)?),
607            None => None,
608        };
609
610        let mut cmd = Command::new(&self.config.binary_path);
611        cmd.arg("-t").arg(validated_template);
612
613        if let Some(m) = validated_model.or(self.config.default_model.as_deref()) {
614            cmd.arg("-m").arg(m);
615        }
616
617        // Add template variables as -p key value pairs
618        for (key, value) in variables {
619            // Validate variable names (alphanumeric + underscores)
620            if !key
621                .chars()
622                .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
623            {
624                return Err(Error::Validation(format!(
625                    "Template variable name contains invalid characters: '{}'",
626                    key
627                )));
628            }
629            // Validate variable values
630            let validated_value = validate_user_input(value)?;
631            cmd.arg("-p").arg(*key).arg(validated_value);
632        }
633
634        let output = cmd.output().map_err(Error::Io)?;
635
636        if !output.status.success() {
637            return Err(Error::Io(std::io::Error::other(
638                String::from_utf8_lossy(&output.stderr).to_string(),
639            )));
640        }
641
642        Ok(PromptResult {
643            response: String::from_utf8_lossy(&output.stdout).to_string(),
644            model: validated_model
645                .or(self.config.default_model.as_deref())
646                .unwrap_or("default")
647                .to_string(),
648            tokens_used: None,
649        })
650    }
651}
652
653impl Default for LlmCliClient {
654    fn default() -> Self {
655        Self::new().expect("Failed to create default LlmCliClient")
656    }
657}
658
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed model identifiers are accepted.
    #[test]
    fn test_validate_model_name_valid() {
        let accepted = [
            "gpt-4o-mini",
            "claude-sonnet-4",
            "sentence-transformers/all-MiniLM-L6-v2",
            "model:latest",
        ];
        for name in accepted {
            assert!(validate_model_name(name).is_ok());
        }
    }

    /// Empty names, shell metacharacters and path traversal are refused.
    #[test]
    fn test_validate_model_name_invalid() {
        let rejected = [
            "",
            "model; rm -rf /",
            "model$(whoami)",
            "../../../etc/passwd",
        ];
        for name in rejected {
            assert!(validate_model_name(name).is_err());
        }
    }

    /// Simple identifiers are accepted as collection names.
    #[test]
    fn test_validate_collection_name_valid() {
        for name in ["my-collection", "collection_123"] {
            assert!(validate_collection_name(name).is_ok());
        }
    }

    /// Empty names, path separators and spaces/semicolons are refused.
    #[test]
    fn test_validate_collection_name_invalid() {
        for name in ["", "collection/path", "col; drop table"] {
            assert!(validate_collection_name(name).is_err());
        }
    }

    /// Ordinary absolute and relative paths are accepted.
    #[test]
    fn test_validate_db_path_valid() {
        for path in ["/home/user/.llm/db.sqlite", "./data/embeddings.db"] {
            assert!(validate_db_path(path).is_ok());
        }
    }

    /// Empty paths, traversal and shell metacharacters are refused.
    #[test]
    fn test_validate_db_path_invalid() {
        let rejected = [
            "",
            "/path/../../../etc/passwd",
            "/path$(whoami)/db",
            "/path`id`/db",
            "/path;rm -rf/db",
        ];
        for path in rejected {
            assert!(validate_db_path(path).is_err());
        }
    }

    /// Control characters are escaped; everything else is passed through.
    #[test]
    fn test_format_dangerous_char() {
        let cases = [('\n', "\\n"), ('\r', "\\r"), ('\0', "\\0"), ('$', "$")];
        for (input, expected) in cases {
            assert_eq!(format_dangerous_char(input), expected);
        }
    }

    /// Constructing a client with the default configuration succeeds.
    #[test]
    fn test_client_creation() {
        assert!(LlmCliClient::new().is_ok());
    }

    /// The default configuration points at `llm` and selects no model.
    #[test]
    fn test_config_default() {
        let LlmCliConfig {
            binary_path,
            default_model,
            ..
        } = LlmCliConfig::default();
        assert_eq!(binary_path, PathBuf::from("llm"));
        assert!(default_model.is_none());
    }
}