Skip to main content

forja_llm/
local.rs

1use async_trait::async_trait;
2use forja_core::error::{ForjaError, Result};
3use forja_core::traits::LlmProvider;
4use forja_core::types::{Message, ToolDefinition};
5use std::path::{Path, PathBuf};
6use std::pin::Pin;
7use tokio_stream::Stream;
8
/// A GGUF model file found on local disk.
///
/// Instances are built by [`detect_local_models`]; `file_name` holds the final
/// path component of `path` as UTF-8 text.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LocalModelInfo {
    /// Bare file name of the model, e.g. `phi-4-mini.gguf`.
    pub file_name: String,
    /// Full path to the model file.
    pub path: PathBuf,
}
14
/// Makes sure `<base_dir>/.forja/models` exists and returns that path.
///
/// Missing intermediate directories are created along the way.
///
/// # Errors
///
/// Propagates any I/O error reported by [`std::fs::create_dir_all`].
pub fn ensure_models_dir(base_dir: &Path) -> std::io::Result<PathBuf> {
    let mut target = base_dir.to_path_buf();
    target.push(".forja");
    target.push("models");
    std::fs::create_dir_all(&target).map(|()| target)
}
20
21pub fn detect_local_models(base_dir: &Path) -> std::io::Result<Vec<LocalModelInfo>> {
22    let models_dir = ensure_models_dir(base_dir)?;
23    let mut models = std::fs::read_dir(&models_dir)?
24        .filter_map(|entry| entry.ok())
25        .map(|entry| entry.path())
26        .filter(|path| {
27            path.extension()
28                .and_then(|extension| extension.to_str())
29                .map(|extension| extension.eq_ignore_ascii_case("gguf"))
30                .unwrap_or(false)
31        })
32        .filter_map(|path| {
33            let file_name = path.file_name()?.to_str()?.to_string();
34            Some(LocalModelInfo { file_name, path })
35        })
36        .collect::<Vec<_>>();
37
38    models.sort_by(|left, right| left.file_name.cmp(&right.file_name));
39    Ok(models)
40}
41
42#[derive(Debug, Clone)]
43pub struct LocalModelProvider {
44    model: LocalModelInfo,
45}
46
47impl LocalModelProvider {
48    pub fn new(model: LocalModelInfo) -> Self {
49        Self { model }
50    }
51
52    pub fn model(&self) -> &LocalModelInfo {
53        &self.model
54    }
55}
56
57#[async_trait]
58impl LlmProvider for LocalModelProvider {
59    async fn chat(
60        &self,
61        _messages: &[Message],
62        _tools: Option<&[ToolDefinition]>,
63    ) -> Result<Message> {
64        Err(ForjaError::LlmError(format!(
65            "Local GGUF inference is not implemented yet for {}",
66            self.model.file_name
67        )))
68    }
69
70    async fn stream(
71        &self,
72        _messages: &[Message],
73        _tools: Option<&[ToolDefinition]>,
74    ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>> {
75        Err(ForjaError::LlmError(format!(
76            "Local GGUF inference is not implemented yet for {}",
77            self.model.file_name
78        )))
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Temporary base directory that is deleted on drop.
    ///
    /// Cleanup runs through `Drop`, so it also happens when an assertion fails
    /// and the test panics, instead of leaking the directory.
    struct TempDir(PathBuf);

    impl TempDir {
        /// Reserves a fresh path under the system temp dir without creating it.
        ///
        /// The name combines the test label, the process id, a per-process
        /// counter and the current time. Parallel tests and concurrent test
        /// runs therefore do not share a directory even on coarse clocks.
        fn new(name: &str) -> Self {
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            let sequence = COUNTER.fetch_add(1, Ordering::SeqCst);
            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos();
            let pid = std::process::id();
            Self(std::env::temp_dir().join(format!(
                "forja_llm_{name}_{pid}_{sequence}_{nanos}"
            )))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best effort: a leftover temp dir must not turn a pass into a failure.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn ensure_models_dir_creates_models_directory() {
        let home_dir = TempDir::new("models_dir");

        let models_dir = ensure_models_dir(home_dir.path()).unwrap();

        assert!(models_dir.is_dir());
        assert_eq!(models_dir, home_dir.path().join(".forja").join("models"));
    }

    #[test]
    fn detect_local_models_finds_gguf_files() {
        let home_dir = TempDir::new("detect_models");
        let models_dir = ensure_models_dir(home_dir.path()).unwrap();
        std::fs::write(models_dir.join("phi-4-mini.gguf"), "stub").unwrap();
        std::fs::write(models_dir.join("notes.txt"), "ignore").unwrap();

        let models = detect_local_models(home_dir.path()).unwrap();

        assert_eq!(models.len(), 1);
        assert_eq!(models[0].file_name, "phi-4-mini.gguf");
        assert_eq!(models[0].path, models_dir.join("phi-4-mini.gguf"));
    }

    #[test]
    fn detect_local_models_matches_extension_case_insensitively_and_sorts() {
        let home_dir = TempDir::new("detect_sorted");
        let models_dir = ensure_models_dir(home_dir.path()).unwrap();
        std::fs::write(models_dir.join("zeta.GGUF"), "stub").unwrap();
        std::fs::write(models_dir.join("alpha.gguf"), "stub").unwrap();
        // A bare `gguf` name has no extension and must not match.
        std::fs::write(models_dir.join("gguf"), "no extension").unwrap();

        let names: Vec<String> = detect_local_models(home_dir.path())
            .unwrap()
            .into_iter()
            .map(|model| model.file_name)
            .collect();

        assert_eq!(names, ["alpha.gguf", "zeta.GGUF"]);
    }

    #[test]
    fn detect_local_models_on_fresh_home_is_empty() {
        let home_dir = TempDir::new("detect_empty");

        let models = detect_local_models(home_dir.path()).unwrap();

        assert!(models.is_empty());
    }
}