Skip to main content

lang_check/
engines.rs

1use crate::checker::{Diagnostic, Severity};
2use anyhow::Result;
3use extism::{Manifest, Plugin, Wasm};
4use harper_core::{
5    Dialect, Document, Lrc,
6    linting::{LintGroup, Linter},
7    parsers::Markdown,
8    spell::FstDictionary,
9};
10use serde::Deserialize;
11use std::path::PathBuf;
12use tracing::{debug, warn};
13
14#[async_trait::async_trait]
15pub trait Engine {
16    fn name(&self) -> &'static str;
17    async fn check(&mut self, text: &str, language_id: &str) -> Result<Vec<Diagnostic>>;
18    /// BCP-47 primary subtags this engine supports. Empty = all languages.
19    fn supported_languages(&self) -> Vec<&'static str> {
20        vec![]
21    }
22}
23
24/// Returns `true` if `engine` supports the given BCP-47 `lang_tag`.
25///
26/// Matching is on the primary subtag: `"en-US"` matches an engine that
27/// advertises `"en"`. An engine with an empty list is a wildcard (supports all).
28pub fn engine_supports_language(engine: &(dyn Engine + Send), lang_tag: &str) -> bool {
29    let supported = engine.supported_languages();
30    if supported.is_empty() {
31        return true;
32    }
33    let primary = lang_tag.split('-').next().unwrap_or(lang_tag);
34    supported.iter().any(|s| s.eq_ignore_ascii_case(primary))
35}
36
37pub struct HarperEngine {
38    linter: LintGroup,
39    dict: Lrc<FstDictionary>,
40}
41
42impl Default for HarperEngine {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl HarperEngine {
49    #[must_use]
50    pub fn new() -> Self {
51        let dict = FstDictionary::curated();
52        let linter = LintGroup::new_curated(dict.clone(), Dialect::American);
53        Self { linter, dict }
54    }
55}
56
57#[async_trait::async_trait]
58impl Engine for HarperEngine {
59    fn name(&self) -> &'static str {
60        "harper"
61    }
62
63    fn supported_languages(&self) -> Vec<&'static str> {
64        vec!["en"]
65    }
66
67    async fn check(&mut self, text: &str, _language_id: &str) -> Result<Vec<Diagnostic>> {
68        let document = Document::new(text, &Markdown::default(), self.dict.as_ref());
69        let lints = self.linter.lint(&document);
70
71        let diagnostics = lints
72            .into_iter()
73            .map(|lint| {
74                let suggestions = lint
75                    .suggestions
76                    .into_iter()
77                    .map(|s| match s {
78                        harper_core::linting::Suggestion::ReplaceWith(chars) => {
79                            chars.into_iter().collect::<String>()
80                        }
81                        harper_core::linting::Suggestion::InsertAfter(chars) => {
82                            let content: String = chars.into_iter().collect();
83                            format!("Insert \"{content}\"")
84                        }
85                        // Empty string replacement = delete the diagnostic range
86                        harper_core::linting::Suggestion::Remove => String::new(),
87                    })
88                    .collect();
89
90                Diagnostic {
91                    #[allow(clippy::cast_possible_truncation)]
92                    start_byte: lint.span.start as u32,
93                    #[allow(clippy::cast_possible_truncation)]
94                    end_byte: lint.span.end as u32,
95                    message: lint.message,
96                    suggestions,
97                    rule_id: format!("harper.{:?}", lint.lint_kind),
98                    severity: Severity::Warning as i32,
99                    unified_id: String::new(), // Will be filled by normalizer
100                    confidence: 0.8,
101                }
102            })
103            .collect();
104
105        Ok(diagnostics)
106    }
107}
108
109pub struct LanguageToolEngine {
110    url: String,
111    client: reqwest::Client,
112}
113
114#[derive(Deserialize)]
115struct LTResponse {
116    matches: Vec<LTMatch>,
117}
118
119#[derive(Deserialize)]
120struct LTMatch {
121    message: String,
122    offset: usize,
123    length: usize,
124    replacements: Vec<LTReplacement>,
125    rule: LTRule,
126}
127
128#[derive(Deserialize)]
129struct LTReplacement {
130    value: String,
131}
132
133#[derive(Deserialize)]
134#[serde(rename_all = "camelCase")]
135struct LTRule {
136    id: String,
137    issue_type: String,
138}
139
140impl LanguageToolEngine {
141    #[must_use]
142    pub fn new(url: String) -> Self {
143        let client = reqwest::Client::builder()
144            .connect_timeout(std::time::Duration::from_secs(3))
145            .timeout(std::time::Duration::from_secs(10))
146            .build()
147            .unwrap_or_default();
148        Self { url, client }
149    }
150}
151
152#[allow(clippy::too_many_lines, clippy::cast_possible_truncation)]
153#[async_trait::async_trait]
154impl Engine for LanguageToolEngine {
155    fn name(&self) -> &'static str {
156        "languagetool"
157    }
158
159    async fn check(&mut self, text: &str, language_id: &str) -> Result<Vec<Diagnostic>> {
160        let url = format!("{}/v2/check", self.url);
161
162        // language_id is now a BCP-47 tag from the orchestrator (e.g. "en-US", "de-DE")
163        let lt_lang = language_id;
164
165        debug!(
166            url = %url,
167            language = lt_lang,
168            text_len = text.len(),
169            "LanguageTool request"
170        );
171
172        let request_start = std::time::Instant::now();
173        let response = match self
174            .client
175            .post(&url)
176            .form(&[("text", &text), ("language", &lt_lang)])
177            .send()
178            .await
179        {
180            Ok(r) => {
181                let status = r.status();
182                debug!(
183                    status = %status,
184                    elapsed_ms = request_start.elapsed().as_millis() as u64,
185                    "LanguageTool HTTP response"
186                );
187                if !status.is_success() {
188                    let body = r.text().await.unwrap_or_default();
189                    warn!(
190                        status = %status,
191                        body = %body,
192                        "LanguageTool returned non-200"
193                    );
194                    return Err(anyhow::anyhow!("LanguageTool HTTP {status}: {body}"));
195                }
196                r
197            }
198            Err(e) => {
199                warn!(
200                    elapsed_ms = request_start.elapsed().as_millis() as u64,
201                    "LanguageTool connection error: {e}"
202                );
203                return Err(anyhow::anyhow!("LanguageTool connection error: {e}"));
204            }
205        };
206
207        let res = match response.json::<LTResponse>().await {
208            Ok(r) => r,
209            Err(e) => {
210                warn!("LanguageTool JSON parse error: {e}");
211                return Err(anyhow::anyhow!("LanguageTool JSON parse error: {e}"));
212            }
213        };
214
215        debug!(
216            matches = res.matches.len(),
217            elapsed_ms = request_start.elapsed().as_millis() as u64,
218            "LanguageTool check complete"
219        );
220
221        let diagnostics = res
222            .matches
223            .into_iter()
224            .map(|m| {
225                let severity = match m.rule.issue_type.as_str() {
226                    "misspelling" => Severity::Error,
227                    "typographical" => Severity::Warning,
228                    _ => Severity::Information,
229                };
230
231                Diagnostic {
232                    #[allow(clippy::cast_possible_truncation)]
233                    start_byte: m.offset as u32,
234                    #[allow(clippy::cast_possible_truncation)]
235                    end_byte: (m.offset + m.length) as u32,
236                    message: m.message,
237                    suggestions: m.replacements.into_iter().map(|r| r.value).collect(),
238                    rule_id: format!("languagetool.{}", m.rule.id),
239                    severity: severity as i32,
240                    unified_id: String::new(), // Will be filled by normalizer
241                    confidence: 0.8,
242                }
243            })
244            .collect();
245
246        Ok(diagnostics)
247    }
248}
249
/// A checker engine that delegates to a subprocess, exchanging JSON over stdin/stdout.
pub struct ExternalEngine {
    /// Provider name, used in log fields and as part of each `rule_id`.
    name: String,
    /// Program to execute.
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
}
256
257impl ExternalEngine {
258    #[must_use]
259    pub const fn new(name: String, command: String, args: Vec<String>) -> Self {
260        Self {
261            name,
262            command,
263            args,
264        }
265    }
266}
267
268/// JSON request sent to the external process on stdin.
269#[derive(serde::Serialize)]
270struct ExternalRequest<'a> {
271    text: &'a str,
272    language_id: &'a str,
273}
274
275/// JSON diagnostic returned by the external process on stdout.
276#[derive(Deserialize)]
277struct ExternalDiagnostic {
278    start_byte: u32,
279    end_byte: u32,
280    message: String,
281    #[serde(default)]
282    suggestions: Vec<String>,
283    #[serde(default)]
284    rule_id: String,
285    #[serde(default = "default_severity_value")]
286    severity: i32,
287    #[serde(default)]
288    confidence: f32,
289}
290
291const fn default_severity_value() -> i32 {
292    Severity::Warning as i32
293}
294
295#[async_trait::async_trait]
296impl Engine for ExternalEngine {
297    fn name(&self) -> &'static str {
298        "external"
299    }
300
301    async fn check(&mut self, text: &str, language_id: &str) -> Result<Vec<Diagnostic>> {
302        use tokio::process::Command;
303
304        let request = ExternalRequest { text, language_id };
305        let input = serde_json::to_string(&request)?;
306
307        let output = match Command::new(&self.command)
308            .args(&self.args)
309            .stdin(std::process::Stdio::piped())
310            .stdout(std::process::Stdio::piped())
311            .stderr(std::process::Stdio::piped())
312            .spawn()
313        {
314            Ok(mut child) => {
315                use tokio::io::AsyncWriteExt;
316                if let Some(mut stdin) = child.stdin.take() {
317                    // Ignore write errors — the process may exit before reading stdin.
318                    let _ = stdin.write_all(input.as_bytes()).await;
319                    let _ = stdin.shutdown().await;
320                }
321                child.wait_with_output().await?
322            }
323            Err(e) => {
324                warn!(provider = %self.name, "Failed to spawn external provider: {e}");
325                return Ok(vec![]);
326            }
327        };
328
329        if !output.status.success() {
330            let stderr = String::from_utf8_lossy(&output.stderr);
331            warn!(
332                provider = %self.name,
333                status = %output.status,
334                stderr = stderr.trim(),
335                "External provider exited with error"
336            );
337            return Ok(vec![]);
338        }
339
340        let stdout = String::from_utf8_lossy(&output.stdout);
341        let ext_diagnostics: Vec<ExternalDiagnostic> = match serde_json::from_str(&stdout) {
342            Ok(d) => d,
343            Err(e) => {
344                warn!(provider = %self.name, "Failed to parse external provider output: {e}");
345                return Ok(vec![]);
346            }
347        };
348
349        let diagnostics = ext_diagnostics
350            .into_iter()
351            .map(|ed| {
352                let rule_id = if ed.rule_id.is_empty() {
353                    format!("external.{}", self.name)
354                } else {
355                    format!("external.{}.{}", self.name, ed.rule_id)
356                };
357                Diagnostic {
358                    start_byte: ed.start_byte,
359                    end_byte: ed.end_byte,
360                    message: ed.message,
361                    suggestions: ed.suggestions,
362                    rule_id,
363                    severity: ed.severity,
364                    unified_id: String::new(),
365                    confidence: if ed.confidence > 0.0 {
366                        ed.confidence
367                    } else {
368                        0.7
369                    },
370                }
371            })
372            .collect();
373
374        Ok(diagnostics)
375    }
376}
377
378/// A WASM checker plugin loaded via Extism.
379///
380/// The plugin must export a `check` function that accepts a JSON string
381/// `{"text": "...", "language_id": "..."}` and returns a JSON array of
382/// diagnostics matching the `ExternalDiagnostic` schema.
383pub struct WasmEngine {
384    name: String,
385    plugin: Plugin,
386}
387
388// SAFETY: Extism Plugin is not Send by default because it wraps a wasmtime Store
389// which holds raw pointers. However, we only ever access the plugin from a single
390// &mut self call at a time (the Engine trait takes &mut self), so this is safe
391// as long as we don't share across threads simultaneously.
392unsafe impl Send for WasmEngine {}
393
394impl WasmEngine {
395    /// Create a new WASM engine from a `.wasm` file path.
396    pub fn new(name: String, wasm_path: PathBuf) -> Result<Self> {
397        let wasm = Wasm::file(wasm_path);
398        let manifest = Manifest::new([wasm]);
399        let plugin = Plugin::new(&manifest, [], true)?;
400        Ok(Self { name, plugin })
401    }
402
403    /// Create a new WASM engine from raw bytes (useful for testing).
404    pub fn from_bytes(name: String, wasm_bytes: &[u8]) -> Result<Self> {
405        let wasm = Wasm::data(wasm_bytes.to_vec());
406        let manifest = Manifest::new([wasm]);
407        let plugin = Plugin::new(&manifest, [], true)?;
408        Ok(Self { name, plugin })
409    }
410}
411
412#[async_trait::async_trait]
413impl Engine for WasmEngine {
414    fn name(&self) -> &'static str {
415        "wasm"
416    }
417
418    async fn check(&mut self, text: &str, language_id: &str) -> Result<Vec<Diagnostic>> {
419        let request = serde_json::json!({
420            "text": text,
421            "language_id": language_id,
422        });
423        let input = request.to_string();
424
425        let output = match self.plugin.call::<&str, &str>("check", &input) {
426            Ok(result) => result.to_string(),
427            Err(e) => {
428                warn!(plugin = %self.name, "WASM plugin call failed: {e}");
429                return Ok(vec![]);
430            }
431        };
432
433        let ext_diagnostics: Vec<ExternalDiagnostic> = match serde_json::from_str(&output) {
434            Ok(d) => d,
435            Err(e) => {
436                warn!(plugin = %self.name, "Failed to parse WASM plugin output: {e}");
437                return Ok(vec![]);
438            }
439        };
440
441        let diagnostics = ext_diagnostics
442            .into_iter()
443            .map(|ed| {
444                let rule_id = if ed.rule_id.is_empty() {
445                    format!("wasm.{}", self.name)
446                } else {
447                    format!("wasm.{}.{}", self.name, ed.rule_id)
448                };
449                Diagnostic {
450                    start_byte: ed.start_byte,
451                    end_byte: ed.end_byte,
452                    message: ed.message,
453                    suggestions: ed.suggestions,
454                    rule_id,
455                    severity: ed.severity,
456                    unified_id: String::new(),
457                    confidence: if ed.confidence > 0.0 {
458                        ed.confidence
459                    } else {
460                        0.7
461                    },
462                }
463            })
464            .collect();
465
466        Ok(diagnostics)
467    }
468}
469
/// Discover WASM plugins from a directory (e.g. `.languagecheck/plugins/`).
///
/// Returns one `(name, path)` pair for each regular file with a `.wasm`
/// extension directly inside `plugin_dir`, where `name` is the file stem.
/// The result is sorted, so the order does not depend on the platform's
/// directory iteration order. A missing or unreadable directory yields an
/// empty list, and unreadable entries are skipped.
#[must_use]
pub fn discover_wasm_plugins(plugin_dir: &std::path::Path) -> Vec<(String, PathBuf)> {
    let Ok(entries) = std::fs::read_dir(plugin_dir) else {
        return Vec::new();
    };

    let mut plugins: Vec<(String, PathBuf)> = entries
        .filter_map(|entry| {
            let path = entry.ok()?.path();
            // A directory that happens to be named `*.wasm` is not a plugin.
            let is_wasm_file = path.extension().is_some_and(|ext| ext == "wasm") && path.is_file();
            if !is_wasm_file {
                return None;
            }
            let name = path
                .file_stem()
                .map(|stem| stem.to_string_lossy().into_owned())
                .unwrap_or_default();
            Some((name, path))
        })
        .collect();

    // `read_dir` order is unspecified; sort for a stable order.
    plugins.sort();
    plugins
}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// Harper flags the article error in a short English sentence.
    #[tokio::test]
    async fn test_harper_engine() -> Result<()> {
        let mut engine = HarperEngine::new();
        let text = "This is an test.";
        let diagnostics = engine.check(text, "en-US").await?;

        // Harper should find "an test" error
        assert!(!diagnostics.is_empty());

        Ok(())
    }

    /// A well-behaved provider's JSON output is mapped field by field.
    #[tokio::test]
    async fn external_engine_with_echo() -> Result<()> {
        // Use a simple shell command that echoes a valid JSON response
        let mut engine = ExternalEngine::new(
            "test-provider".to_string(),
            "sh".to_string(),
            vec![
                "-c".to_string(),
                r#"cat > /dev/null; echo '[{"start_byte":0,"end_byte":4,"message":"test issue","suggestions":["fix"],"rule_id":"test.rule","severity":2}]'"#.to_string(),
            ],
        );

        let diagnostics = engine.check("some text", "markdown").await?;
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(diagnostics[0].message, "test issue");
        assert_eq!(diagnostics[0].rule_id, "external.test-provider.test.rule");
        assert_eq!(diagnostics[0].suggestions, vec!["fix"]);
        assert_eq!(diagnostics[0].start_byte, 0);
        assert_eq!(diagnostics[0].end_byte, 4);

        Ok(())
    }

    /// A provider binary that cannot be spawned yields no diagnostics, not an error.
    #[tokio::test]
    async fn external_engine_missing_binary() -> Result<()> {
        let mut engine = ExternalEngine::new(
            "nonexistent".to_string(),
            "/nonexistent/binary".to_string(),
            vec![],
        );

        // Should not error, just return empty
        let diagnostics = engine.check("text", "markdown").await?;
        assert!(diagnostics.is_empty());

        Ok(())
    }

    /// Provider output that is not valid JSON yields no diagnostics, not an error.
    #[tokio::test]
    async fn external_engine_bad_json_output() -> Result<()> {
        let mut engine = ExternalEngine::new(
            "bad-json".to_string(),
            "echo".to_string(),
            vec!["not json".to_string()],
        );

        // Should not error, just return empty
        let diagnostics = engine.check("text", "markdown").await?;
        assert!(diagnostics.is_empty());

        Ok(())
    }

    #[test]
    fn wasm_engine_invalid_bytes_returns_error() {
        let result = WasmEngine::from_bytes("bad-plugin".to_string(), b"not a wasm file");
        assert!(result.is_err());
    }

    #[test]
    fn wasm_engine_missing_file_returns_error() {
        let result = WasmEngine::new(
            "missing".to_string(),
            PathBuf::from("/nonexistent/plugin.wasm"),
        );
        assert!(result.is_err());
    }

    #[test]
    fn discover_wasm_plugins_empty_dir() {
        let dir = std::env::temp_dir().join("lang_check_test_wasm_empty");
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();

        let plugins = discover_wasm_plugins(&dir);
        assert!(plugins.is_empty());

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn discover_wasm_plugins_finds_wasm_files() {
        let dir = std::env::temp_dir().join("lang_check_test_wasm_discover");
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();

        // Create fake .wasm files and a non-wasm file
        std::fs::write(dir.join("checker.wasm"), b"fake").unwrap();
        std::fs::write(dir.join("linter.wasm"), b"fake").unwrap();
        std::fs::write(dir.join("readme.txt"), b"not a plugin").unwrap();

        let mut plugins = discover_wasm_plugins(&dir);
        plugins.sort_by(|a, b| a.0.cmp(&b.0));

        assert_eq!(plugins.len(), 2);
        assert_eq!(plugins[0].0, "checker");
        assert_eq!(plugins[1].0, "linter");
        assert!(plugins[0].1.ends_with("checker.wasm"));
        assert!(plugins[1].1.ends_with("linter.wasm"));

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn discover_wasm_plugins_nonexistent_dir() {
        let plugins = discover_wasm_plugins(std::path::Path::new("/nonexistent/dir"));
        assert!(plugins.is_empty());
    }

    /// Live integration test — requires LT Docker on localhost:8010.
    /// Run with: `cargo test lt_engine_live -- --ignored --nocapture`
    #[tokio::test]
    #[ignore]
    async fn lt_engine_live() -> Result<()> {
        // Initialize tracing for visible output
        let _ = tracing_subscriber::fmt()
            .with_env_filter("debug")
            .with_writer(std::io::stderr)
            .with_target(false)
            .try_init();

        let mut engine = LanguageToolEngine::new("http://localhost:8010".to_string());
        let text = "This is a sentnce with erors.";
        // The engine forwards the id verbatim as LanguageTool's `language`
        // parameter, so it must be a real language tag, not a file type.
        let diagnostics = engine.check(text, "en-US").await?;

        println!("LT returned {} diagnostics:", diagnostics.len());
        for d in &diagnostics {
            println!(
                "  [{}-{}] {} (rule: {}, suggestions: {:?})",
                d.start_byte, d.end_byte, d.message, d.rule_id, d.suggestions
            );
        }

        assert!(
            diagnostics.len() >= 2,
            "Expected at least 2 spelling errors, got {}",
            diagnostics.len()
        );
        Ok(())
    }

    #[test]
    fn lt_response_deserializes_camel_case() {
        // Real LanguageTool API response (trimmed) — uses camelCase `issueType`
        let json = r#"{
            "matches": [{
                "message": "Possible spelling mistake found.",
                "offset": 10,
                "length": 7,
                "replacements": [{"value": "sentence"}],
                "rule": {
                    "id": "MORFOLOGIK_RULE_EN_US",
                    "description": "Possible spelling mistake",
                    "issueType": "misspelling",
                    "category": {"id": "TYPOS", "name": "Possible Typo"}
                }
            }]
        }"#;
        let res: LTResponse = serde_json::from_str(json).unwrap();
        assert_eq!(res.matches.len(), 1);
        assert_eq!(res.matches[0].rule.id, "MORFOLOGIK_RULE_EN_US");
        assert_eq!(res.matches[0].rule.issue_type, "misspelling");
        assert_eq!(res.matches[0].offset, 10);
        assert_eq!(res.matches[0].length, 7);
        assert_eq!(res.matches[0].replacements[0].value, "sentence");
    }
}