Skip to main content

lang_check/
engines.rs

1use crate::checker::{Diagnostic, Severity};
2use anyhow::Result;
3use extism::{Manifest, Plugin, Wasm};
4use harper_core::{
5    Dialect, Document, Lrc,
6    linting::{LintGroup, Linter},
7    parsers::Markdown,
8    spell::FstDictionary,
9};
10use serde::Deserialize;
11use std::path::PathBuf;
12use tracing::{debug, warn};
13
14#[async_trait::async_trait]
15pub trait Engine {
16    fn name(&self) -> &'static str;
17    async fn check(&mut self, text: &str, language_id: &str) -> Result<Vec<Diagnostic>>;
18}
19
20pub struct HarperEngine {
21    linter: LintGroup,
22    dict: Lrc<FstDictionary>,
23}
24
25impl Default for HarperEngine {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl HarperEngine {
32    #[must_use]
33    pub fn new() -> Self {
34        let dict = FstDictionary::curated();
35        let linter = LintGroup::new_curated(dict.clone(), Dialect::American);
36        Self { linter, dict }
37    }
38}
39
40#[async_trait::async_trait]
41impl Engine for HarperEngine {
42    fn name(&self) -> &'static str {
43        "harper"
44    }
45
46    async fn check(&mut self, text: &str, _language_id: &str) -> Result<Vec<Diagnostic>> {
47        let document = Document::new(text, &Markdown::default(), self.dict.as_ref());
48        let lints = self.linter.lint(&document);
49
50        let diagnostics = lints
51            .into_iter()
52            .map(|lint| {
53                let suggestions = lint
54                    .suggestions
55                    .into_iter()
56                    .map(|s| match s {
57                        harper_core::linting::Suggestion::ReplaceWith(chars) => {
58                            chars.into_iter().collect::<String>()
59                        }
60                        harper_core::linting::Suggestion::InsertAfter(chars) => {
61                            let content: String = chars.into_iter().collect();
62                            format!("Insert \"{content}\"")
63                        }
64                        // Empty string replacement = delete the diagnostic range
65                        harper_core::linting::Suggestion::Remove => String::new(),
66                    })
67                    .collect();
68
69                Diagnostic {
70                    #[allow(clippy::cast_possible_truncation)]
71                    start_byte: lint.span.start as u32,
72                    #[allow(clippy::cast_possible_truncation)]
73                    end_byte: lint.span.end as u32,
74                    message: lint.message,
75                    suggestions,
76                    rule_id: format!("harper.{:?}", lint.lint_kind),
77                    severity: Severity::Warning as i32,
78                    unified_id: String::new(), // Will be filled by normalizer
79                    confidence: 0.8,
80                }
81            })
82            .collect();
83
84        Ok(diagnostics)
85    }
86}
87
88pub struct LanguageToolEngine {
89    url: String,
90    client: reqwest::Client,
91}
92
93#[derive(Deserialize)]
94struct LTResponse {
95    matches: Vec<LTMatch>,
96}
97
98#[derive(Deserialize)]
99struct LTMatch {
100    message: String,
101    offset: usize,
102    length: usize,
103    replacements: Vec<LTReplacement>,
104    rule: LTRule,
105}
106
107#[derive(Deserialize)]
108struct LTReplacement {
109    value: String,
110}
111
112#[derive(Deserialize)]
113#[serde(rename_all = "camelCase")]
114struct LTRule {
115    id: String,
116    issue_type: String,
117}
118
119impl LanguageToolEngine {
120    #[must_use]
121    pub fn new(url: String) -> Self {
122        let client = reqwest::Client::builder()
123            .connect_timeout(std::time::Duration::from_secs(3))
124            .timeout(std::time::Duration::from_secs(10))
125            .build()
126            .unwrap_or_default();
127        Self { url, client }
128    }
129}
130
131#[allow(clippy::too_many_lines, clippy::cast_possible_truncation)]
132#[async_trait::async_trait]
133impl Engine for LanguageToolEngine {
134    fn name(&self) -> &'static str {
135        "languagetool"
136    }
137
138    async fn check(&mut self, text: &str, language_id: &str) -> Result<Vec<Diagnostic>> {
139        let url = format!("{}/v2/check", self.url);
140
141        // Map file-type language IDs to BCP-47 codes that LanguageTool understands
142        let lt_lang = match language_id {
143            "markdown" | "html" | "latex" | "text" | "rst" | "asciidoc" | "typst" | "djot"
144            | "org" | "bibtex" | "forester" | "sweave" => "en-US",
145            other => other,
146        };
147
148        debug!(
149            url = %url,
150            language = lt_lang,
151            text_len = text.len(),
152            "LanguageTool request"
153        );
154
155        let request_start = std::time::Instant::now();
156        let response = match self
157            .client
158            .post(&url)
159            .form(&[("text", &text), ("language", &lt_lang)])
160            .send()
161            .await
162        {
163            Ok(r) => {
164                let status = r.status();
165                debug!(
166                    status = %status,
167                    elapsed_ms = request_start.elapsed().as_millis() as u64,
168                    "LanguageTool HTTP response"
169                );
170                if !status.is_success() {
171                    let body = r.text().await.unwrap_or_default();
172                    warn!(
173                        status = %status,
174                        body = %body,
175                        "LanguageTool returned non-200"
176                    );
177                    return Ok(vec![]);
178                }
179                r
180            }
181            Err(e) => {
182                warn!(
183                    elapsed_ms = request_start.elapsed().as_millis() as u64,
184                    "LanguageTool connection error: {e}"
185                );
186                return Ok(vec![]);
187            }
188        };
189
190        let res = match response.json::<LTResponse>().await {
191            Ok(r) => r,
192            Err(e) => {
193                warn!("LanguageTool JSON parse error: {e}");
194                return Ok(vec![]);
195            }
196        };
197
198        debug!(
199            matches = res.matches.len(),
200            elapsed_ms = request_start.elapsed().as_millis() as u64,
201            "LanguageTool check complete"
202        );
203
204        let diagnostics = res
205            .matches
206            .into_iter()
207            .map(|m| {
208                let severity = match m.rule.issue_type.as_str() {
209                    "misspelling" => Severity::Error,
210                    "typographical" => Severity::Warning,
211                    _ => Severity::Information,
212                };
213
214                Diagnostic {
215                    #[allow(clippy::cast_possible_truncation)]
216                    start_byte: m.offset as u32,
217                    #[allow(clippy::cast_possible_truncation)]
218                    end_byte: (m.offset + m.length) as u32,
219                    message: m.message,
220                    suggestions: m.replacements.into_iter().map(|r| r.value).collect(),
221                    rule_id: format!("languagetool.{}", m.rule.id),
222                    severity: severity as i32,
223                    unified_id: String::new(), // Will be filled by normalizer
224                    confidence: 0.7,
225                }
226            })
227            .collect();
228
229        Ok(diagnostics)
230    }
231}
232
/// Checker backend that runs an external program and exchanges JSON with it.
///
/// Each check spawns `command` with `args`, writes a JSON request to the child's
/// stdin and parses a JSON array of diagnostics from its stdout.
pub struct ExternalEngine {
    /// Provider name; appears in logs and in each diagnostic's rule id.
    name: String,
    /// Program to execute.
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
}

impl ExternalEngine {
    /// Describe an external provider. Only stores the configuration; nothing is
    /// spawned until a check runs.
    #[must_use]
    pub const fn new(name: String, command: String, args: Vec<String>) -> Self {
        Self { name, command, args }
    }
}
250
251/// JSON request sent to the external process on stdin.
252#[derive(serde::Serialize)]
253struct ExternalRequest<'a> {
254    text: &'a str,
255    language_id: &'a str,
256}
257
258/// JSON diagnostic returned by the external process on stdout.
259#[derive(Deserialize)]
260struct ExternalDiagnostic {
261    start_byte: u32,
262    end_byte: u32,
263    message: String,
264    #[serde(default)]
265    suggestions: Vec<String>,
266    #[serde(default)]
267    rule_id: String,
268    #[serde(default = "default_severity_value")]
269    severity: i32,
270    #[serde(default)]
271    confidence: f32,
272}
273
274const fn default_severity_value() -> i32 {
275    Severity::Warning as i32
276}
277
278#[async_trait::async_trait]
279impl Engine for ExternalEngine {
280    fn name(&self) -> &'static str {
281        "external"
282    }
283
284    async fn check(&mut self, text: &str, language_id: &str) -> Result<Vec<Diagnostic>> {
285        use tokio::process::Command;
286
287        let request = ExternalRequest { text, language_id };
288        let input = serde_json::to_string(&request)?;
289
290        let output = match Command::new(&self.command)
291            .args(&self.args)
292            .stdin(std::process::Stdio::piped())
293            .stdout(std::process::Stdio::piped())
294            .stderr(std::process::Stdio::piped())
295            .spawn()
296        {
297            Ok(mut child) => {
298                use tokio::io::AsyncWriteExt;
299                if let Some(mut stdin) = child.stdin.take() {
300                    // Ignore write errors — the process may exit before reading stdin.
301                    let _ = stdin.write_all(input.as_bytes()).await;
302                    let _ = stdin.shutdown().await;
303                }
304                child.wait_with_output().await?
305            }
306            Err(e) => {
307                warn!(provider = %self.name, "Failed to spawn external provider: {e}");
308                return Ok(vec![]);
309            }
310        };
311
312        if !output.status.success() {
313            let stderr = String::from_utf8_lossy(&output.stderr);
314            warn!(
315                provider = %self.name,
316                status = %output.status,
317                stderr = stderr.trim(),
318                "External provider exited with error"
319            );
320            return Ok(vec![]);
321        }
322
323        let stdout = String::from_utf8_lossy(&output.stdout);
324        let ext_diagnostics: Vec<ExternalDiagnostic> = match serde_json::from_str(&stdout) {
325            Ok(d) => d,
326            Err(e) => {
327                warn!(provider = %self.name, "Failed to parse external provider output: {e}");
328                return Ok(vec![]);
329            }
330        };
331
332        let diagnostics = ext_diagnostics
333            .into_iter()
334            .map(|ed| {
335                let rule_id = if ed.rule_id.is_empty() {
336                    format!("external.{}", self.name)
337                } else {
338                    format!("external.{}.{}", self.name, ed.rule_id)
339                };
340                Diagnostic {
341                    start_byte: ed.start_byte,
342                    end_byte: ed.end_byte,
343                    message: ed.message,
344                    suggestions: ed.suggestions,
345                    rule_id,
346                    severity: ed.severity,
347                    unified_id: String::new(),
348                    confidence: if ed.confidence > 0.0 {
349                        ed.confidence
350                    } else {
351                        0.7
352                    },
353                }
354            })
355            .collect();
356
357        Ok(diagnostics)
358    }
359}
360
361/// A WASM checker plugin loaded via Extism.
362///
363/// The plugin must export a `check` function that accepts a JSON string
364/// `{"text": "...", "language_id": "..."}` and returns a JSON array of
365/// diagnostics matching the `ExternalDiagnostic` schema.
366pub struct WasmEngine {
367    name: String,
368    plugin: Plugin,
369}
370
371// SAFETY: Extism Plugin is not Send by default because it wraps a wasmtime Store
372// which holds raw pointers. However, we only ever access the plugin from a single
373// &mut self call at a time (the Engine trait takes &mut self), so this is safe
374// as long as we don't share across threads simultaneously.
375unsafe impl Send for WasmEngine {}
376
377impl WasmEngine {
378    /// Create a new WASM engine from a `.wasm` file path.
379    pub fn new(name: String, wasm_path: PathBuf) -> Result<Self> {
380        let wasm = Wasm::file(wasm_path);
381        let manifest = Manifest::new([wasm]);
382        let plugin = Plugin::new(&manifest, [], true)?;
383        Ok(Self { name, plugin })
384    }
385
386    /// Create a new WASM engine from raw bytes (useful for testing).
387    pub fn from_bytes(name: String, wasm_bytes: &[u8]) -> Result<Self> {
388        let wasm = Wasm::data(wasm_bytes.to_vec());
389        let manifest = Manifest::new([wasm]);
390        let plugin = Plugin::new(&manifest, [], true)?;
391        Ok(Self { name, plugin })
392    }
393}
394
395#[async_trait::async_trait]
396impl Engine for WasmEngine {
397    fn name(&self) -> &'static str {
398        "wasm"
399    }
400
401    async fn check(&mut self, text: &str, language_id: &str) -> Result<Vec<Diagnostic>> {
402        let request = serde_json::json!({
403            "text": text,
404            "language_id": language_id,
405        });
406        let input = request.to_string();
407
408        let output = match self.plugin.call::<&str, &str>("check", &input) {
409            Ok(result) => result.to_string(),
410            Err(e) => {
411                warn!(plugin = %self.name, "WASM plugin call failed: {e}");
412                return Ok(vec![]);
413            }
414        };
415
416        let ext_diagnostics: Vec<ExternalDiagnostic> = match serde_json::from_str(&output) {
417            Ok(d) => d,
418            Err(e) => {
419                warn!(plugin = %self.name, "Failed to parse WASM plugin output: {e}");
420                return Ok(vec![]);
421            }
422        };
423
424        let diagnostics = ext_diagnostics
425            .into_iter()
426            .map(|ed| {
427                let rule_id = if ed.rule_id.is_empty() {
428                    format!("wasm.{}", self.name)
429                } else {
430                    format!("wasm.{}.{}", self.name, ed.rule_id)
431                };
432                Diagnostic {
433                    start_byte: ed.start_byte,
434                    end_byte: ed.end_byte,
435                    message: ed.message,
436                    suggestions: ed.suggestions,
437                    rule_id,
438                    severity: ed.severity,
439                    unified_id: String::new(),
440                    confidence: if ed.confidence > 0.0 {
441                        ed.confidence
442                    } else {
443                        0.7
444                    },
445                }
446            })
447            .collect();
448
449        Ok(diagnostics)
450    }
451}
452
/// Discover WASM plugins in a directory (e.g. `.languagecheck/plugins/`).
///
/// Returns a `(name, path)` pair for every regular file (symlinks are followed)
/// whose extension is exactly `wasm`. `name` is the file stem. Directories whose
/// names end in `.wasm` are skipped because they cannot be loaded as plugins.
///
/// The result is sorted by name, then path, so plugin load order is deterministic.
/// `std::fs::read_dir` makes no ordering guarantee. A missing or unreadable
/// directory yields an empty list, and entries that cannot be read are skipped.
#[must_use]
pub fn discover_wasm_plugins(plugin_dir: &std::path::Path) -> Vec<(String, PathBuf)> {
    let Ok(entries) = std::fs::read_dir(plugin_dir) else {
        return Vec::new();
    };

    let mut plugins: Vec<(String, PathBuf)> = entries
        .filter_map(|entry| {
            let path = entry.ok()?.path();
            if !path.extension().is_some_and(|e| e == "wasm") || !path.is_file() {
                return None;
            }
            let name = path
                .file_stem()
                .map(|s| s.to_string_lossy().into_owned())
                .unwrap_or_default();
            Some((name, path))
        })
        .collect();
    plugins.sort_unstable();
    plugins
}
477
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-process scratch directory under the system temp dir.
    ///
    /// The process id in the name keeps concurrent test runs (e.g. two checkouts
    /// running `cargo test` at the same time) from deleting each other's fixtures.
    fn scratch_dir(tag: &str) -> PathBuf {
        std::env::temp_dir().join(format!("lang_check_test_{tag}_{}", std::process::id()))
    }

    #[tokio::test]
    async fn test_harper_engine() -> Result<()> {
        let mut engine = HarperEngine::new();
        let text = "This is an test.";
        let diagnostics = engine.check(text, "en-US").await?;

        // Harper should find "an test" error
        assert!(!diagnostics.is_empty());

        Ok(())
    }

    /// Relies on a POSIX `sh`, so it is only compiled on Unix-like targets.
    #[cfg(unix)]
    #[tokio::test]
    async fn external_engine_with_echo() -> Result<()> {
        // Use a simple shell command that echoes a valid JSON response
        let mut engine = ExternalEngine::new(
            "test-provider".to_string(),
            "sh".to_string(),
            vec![
                "-c".to_string(),
                r#"cat > /dev/null; echo '[{"start_byte":0,"end_byte":4,"message":"test issue","suggestions":["fix"],"rule_id":"test.rule","severity":2}]'"#.to_string(),
            ],
        );

        let diagnostics = engine.check("some text", "markdown").await?;
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(diagnostics[0].message, "test issue");
        assert_eq!(diagnostics[0].rule_id, "external.test-provider.test.rule");
        assert_eq!(diagnostics[0].suggestions, vec!["fix"]);
        assert_eq!(diagnostics[0].start_byte, 0);
        assert_eq!(diagnostics[0].end_byte, 4);

        Ok(())
    }

    #[tokio::test]
    async fn external_engine_missing_binary() -> Result<()> {
        let mut engine = ExternalEngine::new(
            "nonexistent".to_string(),
            "/nonexistent/binary".to_string(),
            vec![],
        );

        // Should not error, just return empty
        let diagnostics = engine.check("text", "markdown").await?;
        assert!(diagnostics.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn external_engine_bad_json_output() -> Result<()> {
        let mut engine = ExternalEngine::new(
            "bad-json".to_string(),
            "echo".to_string(),
            vec!["not json".to_string()],
        );

        // Should not error, just return empty
        let diagnostics = engine.check("text", "markdown").await?;
        assert!(diagnostics.is_empty());

        Ok(())
    }

    #[test]
    fn wasm_engine_invalid_bytes_returns_error() {
        let result = WasmEngine::from_bytes("bad-plugin".to_string(), b"not a wasm file");
        assert!(result.is_err());
    }

    #[test]
    fn wasm_engine_missing_file_returns_error() {
        let result = WasmEngine::new(
            "missing".to_string(),
            PathBuf::from("/nonexistent/plugin.wasm"),
        );
        assert!(result.is_err());
    }

    #[test]
    fn discover_wasm_plugins_empty_dir() {
        let dir = scratch_dir("wasm_empty");
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();

        let plugins = discover_wasm_plugins(&dir);
        assert!(plugins.is_empty());

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn discover_wasm_plugins_finds_wasm_files() {
        let dir = scratch_dir("wasm_discover");
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();

        // Create fake .wasm files and a non-wasm file
        std::fs::write(dir.join("checker.wasm"), b"fake").unwrap();
        std::fs::write(dir.join("linter.wasm"), b"fake").unwrap();
        std::fs::write(dir.join("readme.txt"), b"not a plugin").unwrap();

        let mut plugins = discover_wasm_plugins(&dir);
        plugins.sort_by(|a, b| a.0.cmp(&b.0));

        assert_eq!(plugins.len(), 2);
        assert_eq!(plugins[0].0, "checker");
        assert_eq!(plugins[1].0, "linter");
        assert!(plugins[0].1.ends_with("checker.wasm"));
        assert!(plugins[1].1.ends_with("linter.wasm"));

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn discover_wasm_plugins_nonexistent_dir() {
        let plugins = discover_wasm_plugins(std::path::Path::new("/nonexistent/dir"));
        assert!(plugins.is_empty());
    }

    /// Live integration test — requires LT Docker on localhost:8010.
    /// Run with: `cargo test lt_engine_live -- --ignored --nocapture`
    #[tokio::test]
    #[ignore]
    async fn lt_engine_live() -> Result<()> {
        // Initialize tracing for visible output
        let _ = tracing_subscriber::fmt()
            .with_env_filter("debug")
            .with_writer(std::io::stderr)
            .with_target(false)
            .try_init();

        let mut engine = LanguageToolEngine::new("http://localhost:8010".to_string());
        let text = "This is a sentnce with erors.";
        let diagnostics = engine.check(text, "markdown").await?;

        println!("LT returned {} diagnostics:", diagnostics.len());
        for d in &diagnostics {
            println!(
                "  [{}-{}] {} (rule: {}, suggestions: {:?})",
                d.start_byte, d.end_byte, d.message, d.rule_id, d.suggestions
            );
        }

        assert!(
            diagnostics.len() >= 2,
            "Expected at least 2 spelling errors, got {}",
            diagnostics.len()
        );
        Ok(())
    }

    #[test]
    fn lt_response_deserializes_camel_case() {
        // Real LanguageTool API response (trimmed) — uses camelCase `issueType`
        let json = r#"{
            "matches": [{
                "message": "Possible spelling mistake found.",
                "offset": 10,
                "length": 7,
                "replacements": [{"value": "sentence"}],
                "rule": {
                    "id": "MORFOLOGIK_RULE_EN_US",
                    "description": "Possible spelling mistake",
                    "issueType": "misspelling",
                    "category": {"id": "TYPOS", "name": "Possible Typo"}
                }
            }]
        }"#;
        let res: LTResponse = serde_json::from_str(json).unwrap();
        assert_eq!(res.matches.len(), 1);
        assert_eq!(res.matches[0].rule.id, "MORFOLOGIK_RULE_EN_US");
        assert_eq!(res.matches[0].rule.issue_type, "misspelling");
        assert_eq!(res.matches[0].offset, 10);
        assert_eq!(res.matches[0].length, 7);
        assert_eq!(res.matches[0].replacements[0].value, "sentence");
    }
}