Skip to main content

construct/rag/
mod.rs

1//! RAG pipeline for hardware datasheet retrieval.
2//!
3//! Supports:
4//! - Markdown and text datasheets (always)
5//! - PDF ingestion (with `rag-pdf` feature)
6//! - Pin/alias tables (e.g. `red_led: 13`) for explicit lookup
7//! - Keyword retrieval (default) or semantic search via embeddings (optional)
8
9use crate::memory::chunker;
10use std::collections::HashMap;
11use std::path::Path;
12
/// One retrievable slice of a datasheet, tagged with the board it belongs to.
#[derive(Debug, Clone)]
pub struct DatasheetChunk {
    /// Board tag this chunk is associated with (e.g. "nucleo-f401re", "rpi-gpio");
    /// `None` marks generic content that is not tied to a single board.
    pub board: Option<String>,
    /// Path of the file the chunk was cut from (kept for debugging).
    pub source: String,
    /// Text of the chunk itself.
    pub content: String,
}
23
/// Pin alias map: human-readable name → pin number (e.g. "red_led" → 13).
///
/// Keys produced by `parse_pin_aliases` are lower-cased with spaces replaced
/// by underscores.
pub type PinAliases = HashMap<String, u32>;
26
27/// Parse pin aliases from markdown. Looks for:
28/// - `## Pin Aliases` section with `alias: pin` lines
29/// - Markdown table `| alias | pin |`
30fn parse_pin_aliases(content: &str) -> PinAliases {
31    let mut aliases = PinAliases::new();
32    let content_lower = content.to_lowercase();
33
34    // Find ## Pin Aliases section
35    let section_markers = ["## pin aliases", "## pin alias", "## pins"];
36    let mut in_section = false;
37    let mut section_start = 0;
38
39    for marker in section_markers {
40        if let Some(pos) = content_lower.find(marker) {
41            in_section = true;
42            section_start = pos + marker.len();
43            break;
44        }
45    }
46
47    if !in_section {
48        return aliases;
49    }
50
51    let rest = &content[section_start..];
52    let section_end = rest
53        .find("\n## ")
54        .map(|i| section_start + i)
55        .unwrap_or(content.len());
56    let section = &content[section_start..section_end];
57
58    // Parse "alias: pin" or "alias = pin" lines
59    for line in section.lines() {
60        let line = line.trim();
61        if line.is_empty() {
62            continue;
63        }
64        // Table row: | red_led | 13 | (skip header | alias | pin | and separator |---|)
65        if line.starts_with('|') {
66            let parts: Vec<&str> = line.split('|').map(|s| s.trim()).collect();
67            if parts.len() >= 3 {
68                let alias = parts[1].trim().to_lowercase().replace(' ', "_");
69                let pin_str = parts[2].trim();
70                // Skip header row and separator (|---|)
71                if alias.eq("alias")
72                    || alias.eq("pin")
73                    || pin_str.eq("pin")
74                    || alias.contains("---")
75                    || pin_str.contains("---")
76                {
77                    continue;
78                }
79                if let Ok(pin) = pin_str.parse::<u32>() {
80                    if !alias.is_empty() {
81                        aliases.insert(alias, pin);
82                    }
83                }
84            }
85            continue;
86        }
87        // Key: value
88        if let Some((k, v)) = line.split_once(':').or_else(|| line.split_once('=')) {
89            let alias = k.trim().to_lowercase().replace(' ', "_");
90            if let Ok(pin) = v.trim().parse::<u32>() {
91                if !alias.is_empty() {
92                    aliases.insert(alias, pin);
93                }
94            }
95        }
96    }
97
98    aliases
99}
100
/// Recursively collect every Markdown / plain-text file under `dir` into `out`.
///
/// A file qualifies when its extension is `md` or `txt`, compared ASCII
/// case-insensitively so that `BOARD.MD` or `NOTES.TXT` are picked up as well.
/// Directories that cannot be listed and entries that cannot be read are
/// skipped silently (best-effort scan). Paths are appended in the order the
/// filesystem reports them.
fn collect_md_txt_paths(dir: &Path, out: &mut Vec<std::path::PathBuf>) {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    for entry in entries.flatten() {
        let path = entry.path();
        if path.is_dir() {
            collect_md_txt_paths(&path, out);
        } else if path.is_file() {
            let is_text_sheet = path
                .extension()
                .and_then(|ext| ext.to_str())
                .map_or(false, |ext| {
                    ext.eq_ignore_ascii_case("md") || ext.eq_ignore_ascii_case("txt")
                });
            if is_text_sheet {
                out.push(path);
            }
        }
    }
}
117
/// Walk `dir` recursively and append every file whose extension is exactly
/// `pdf` to `out`. Unreadable directories and entries are skipped.
#[cfg(feature = "rag-pdf")]
fn collect_pdf_paths(dir: &Path, out: &mut Vec<std::path::PathBuf>) {
    let listing = match std::fs::read_dir(dir) {
        Ok(listing) => listing,
        Err(_) => return,
    };
    for candidate in listing.filter_map(Result::ok) {
        let candidate = candidate.path();
        if candidate.is_dir() {
            // Descend into sub-directories.
            collect_pdf_paths(&candidate, out);
            continue;
        }
        if candidate.is_file()
            && candidate.extension().and_then(|ext| ext.to_str()) == Some("pdf")
        {
            out.push(candidate);
        }
    }
}
134
/// Read the PDF at `path` and return its extracted text.
///
/// Yields `None` when the file cannot be read or when text extraction fails.
#[cfg(feature = "rag-pdf")]
fn extract_pdf_text(path: &Path) -> Option<String> {
    match std::fs::read(path) {
        Ok(raw) => pdf_extract::extract_text_from_mem(&raw).ok(),
        Err(_) => None,
    }
}
140
141/// Hardware RAG index — loads and retrieves datasheet chunks.
142pub struct HardwareRag {
143    chunks: Vec<DatasheetChunk>,
144    /// Per-board pin aliases (board -> alias -> pin).
145    pin_aliases: HashMap<String, PinAliases>,
146}
147
148impl HardwareRag {
149    /// Load datasheets from a directory. Expects .md, .txt, and optionally .pdf (with rag-pdf).
150    /// Filename (without extension) is used as board tag.
151    /// Supports `## Pin Aliases` section for explicit alias→pin mapping.
152    pub fn load(workspace_dir: &Path, datasheet_dir: &str) -> anyhow::Result<Self> {
153        let base = workspace_dir.join(datasheet_dir);
154        if !base.exists() || !base.is_dir() {
155            return Ok(Self {
156                chunks: Vec::new(),
157                pin_aliases: HashMap::new(),
158            });
159        }
160
161        let mut paths: Vec<std::path::PathBuf> = Vec::new();
162        collect_md_txt_paths(&base, &mut paths);
163        #[cfg(feature = "rag-pdf")]
164        collect_pdf_paths(&base, &mut paths);
165
166        let mut chunks = Vec::new();
167        let mut pin_aliases: HashMap<String, PinAliases> = HashMap::new();
168        let max_tokens = 512;
169
170        for path in paths {
171            let content = if path.extension().and_then(|e| e.to_str()) == Some("pdf") {
172                #[cfg(feature = "rag-pdf")]
173                {
174                    extract_pdf_text(&path).unwrap_or_default()
175                }
176                #[cfg(not(feature = "rag-pdf"))]
177                {
178                    String::new()
179                }
180            } else {
181                std::fs::read_to_string(&path).unwrap_or_default()
182            };
183
184            if content.trim().is_empty() {
185                continue;
186            }
187
188            let board = infer_board_from_path(&path, &base);
189            let source = path
190                .strip_prefix(workspace_dir)
191                .unwrap_or(&path)
192                .display()
193                .to_string();
194
195            // Parse pin aliases from full content
196            let aliases = parse_pin_aliases(&content);
197            if let Some(ref b) = board {
198                if !aliases.is_empty() {
199                    pin_aliases.insert(b.clone(), aliases);
200                }
201            }
202
203            for chunk in chunker::chunk_markdown(&content, max_tokens) {
204                chunks.push(DatasheetChunk {
205                    board: board.clone(),
206                    source: source.clone(),
207                    content: chunk.content,
208                });
209            }
210        }
211
212        Ok(Self {
213            chunks,
214            pin_aliases,
215        })
216    }
217
218    /// Get pin aliases for a board (e.g. "red_led" -> 13).
219    pub fn pin_aliases_for_board(&self, board: &str) -> Option<&PinAliases> {
220        self.pin_aliases.get(board)
221    }
222
223    /// Build pin-alias context for query. When user says "red led", inject "red_led: 13" for matching boards.
224    pub fn pin_alias_context(&self, query: &str, boards: &[String]) -> String {
225        let query_lower = query.to_lowercase();
226        let query_words: Vec<&str> = query_lower
227            .split_whitespace()
228            .filter(|w| w.len() > 1)
229            .collect();
230
231        let mut lines = Vec::new();
232        for board in boards {
233            if let Some(aliases) = self.pin_aliases.get(board) {
234                for (alias, pin) in aliases {
235                    let alias_words: Vec<&str> = alias.split('_').collect();
236                    let matches = query_words.iter().any(|qw| alias_words.contains(qw))
237                        || query_lower.contains(&alias.replace('_', " "));
238                    if matches {
239                        lines.push(format!("{board}: {alias} = pin {pin}"));
240                    }
241                }
242            }
243        }
244        if lines.is_empty() {
245            return String::new();
246        }
247        format!("[Pin aliases for query]\n{}\n\n", lines.join("\n"))
248    }
249
250    /// Retrieve chunks relevant to the query and boards.
251    /// Uses keyword matching and board filter. Pin-alias context is built separately via `pin_alias_context`.
252    pub fn retrieve(&self, query: &str, boards: &[String], limit: usize) -> Vec<&DatasheetChunk> {
253        if self.chunks.is_empty() || limit == 0 {
254            return Vec::new();
255        }
256
257        let query_lower = query.to_lowercase();
258        let query_terms: Vec<&str> = query_lower
259            .split_whitespace()
260            .filter(|w| w.len() > 2)
261            .collect();
262
263        let mut scored: Vec<(&DatasheetChunk, f32)> = Vec::new();
264        for chunk in &self.chunks {
265            let content_lower = chunk.content.to_lowercase();
266            let mut score = 0.0f32;
267
268            for term in &query_terms {
269                if content_lower.contains(term) {
270                    score += 1.0;
271                }
272            }
273
274            if score > 0.0 {
275                let board_match = chunk.board.as_ref().map_or(false, |b| boards.contains(b));
276                if board_match {
277                    score += 2.0;
278                }
279                scored.push((chunk, score));
280            }
281        }
282
283        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
284        scored.truncate(limit);
285        scored.into_iter().map(|(c, _)| c).collect()
286    }
287
288    /// Number of indexed chunks.
289    pub fn len(&self) -> usize {
290        self.chunks.len()
291    }
292
293    /// True if no chunks are indexed.
294    pub fn is_empty(&self) -> bool {
295        self.chunks.is_empty()
296    }
297}
298
/// Derive the board tag for a datasheet from its location under `base`.
///
/// The tag is the file stem, e.g. `nucleo-f401re.md` → `Some("nucleo-f401re")`.
/// `None` is returned for generic sheets — a stem of `generic` or starting
/// with `generic_`, or a file sitting directly in a top-level `_generic`
/// directory — and also when `path` is not under `base` or the stem is not
/// valid UTF-8.
fn infer_board_from_path(path: &Path, base: &Path) -> Option<String> {
    let relative = path.strip_prefix(base).ok()?;
    let name = path.file_stem().and_then(|stem| stem.to_str())?;

    let generic_by_name = name == "generic" || name.starts_with("generic_");
    let generic_by_dir = matches!(
        relative.parent().and_then(Path::to_str),
        Some("_generic")
    );

    if generic_by_name || generic_by_dir {
        None
    } else {
        Some(name.to_owned())
    }
}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// `alias: pin` lines under the heading are all recognised.
    #[test]
    fn parse_pin_aliases_key_value() {
        let doc = "## Pin Aliases\nred_led: 13\nbuiltin_led: 13\nuser_led: 5";
        let parsed = parse_pin_aliases(doc);
        for (name, pin) in [("red_led", 13u32), ("builtin_led", 13), ("user_led", 5)] {
            assert_eq!(parsed.get(name), Some(&pin));
        }
    }

    /// Table rows are parsed; the header and separator rows are not.
    #[test]
    fn parse_pin_aliases_table() {
        let doc = "## Pin Aliases\n| alias | pin |\n|-------|-----|\n| red_led | 13 |\n| builtin_led | 13 |";
        let parsed = parse_pin_aliases(doc);
        for name in ["red_led", "builtin_led"] {
            assert_eq!(parsed.get(name), Some(&13));
        }
    }

    /// A document without an alias section produces no aliases.
    #[test]
    fn parse_pin_aliases_empty() {
        assert!(parse_pin_aliases("No aliases here").is_empty());
    }

    /// The file stem becomes the board tag.
    #[test]
    fn infer_board_from_path_nucleo() {
        let root = std::path::Path::new("/base");
        let sheet = std::path::Path::new("/base/nucleo-f401re.md");
        let expected: Option<String> = Some("nucleo-f401re".into());
        assert_eq!(infer_board_from_path(sheet, root), expected);
    }

    /// A sheet named `generic` carries no board tag.
    #[test]
    fn infer_board_generic_none() {
        let root = std::path::Path::new("/base");
        let sheet = std::path::Path::new("/base/generic.md");
        assert_eq!(infer_board_from_path(sheet, root), None);
    }

    /// End-to-end: load one sheet, retrieve by keyword, resolve an alias.
    #[test]
    fn hardware_rag_load_and_retrieve() {
        let workspace = tempfile::tempdir().unwrap();
        let sheets_dir = workspace.path().join("datasheets");
        std::fs::create_dir_all(&sheets_dir).unwrap();
        let sheet = "# Test Board\n## Pin Aliases\nred_led: 13\n## GPIO\nPin 13: LED\n";
        std::fs::write(sheets_dir.join("test-board.md"), sheet).unwrap();

        let index = HardwareRag::load(workspace.path(), "datasheets").unwrap();
        assert!(!index.is_empty());

        let boards = vec!["test-board".to_string()];
        let hits = index.retrieve("led", &boards, 5);
        assert!(!hits.is_empty());

        let alias_ctx = index.pin_alias_context("red led", &boards);
        assert!(alias_ctx.contains("13"));
    }

    /// An existing but empty datasheet directory yields an empty index.
    #[test]
    fn hardware_rag_load_empty_dir() {
        let workspace = tempfile::tempdir().unwrap();
        std::fs::create_dir_all(workspace.path().join("empty_ds")).unwrap();
        let index = HardwareRag::load(workspace.path(), "empty_ds").unwrap();
        assert!(index.is_empty());
    }
}