Skip to main content

lang_check/engines/
mod.rs

1mod proselint;
2mod vale;
3
4pub use proselint::ProselintEngine;
5pub use vale::ValeEngine;
6
7use crate::checker::{Diagnostic, Severity};
8use anyhow::Result;
9use extism::{Manifest, Plugin, Wasm};
10use harper_core::{
11    Dialect, Document, Lrc,
12    linting::{LintGroup, Linter},
13    parsers::Markdown,
14    spell::FstDictionary,
15};
16use serde::Deserialize;
17use std::path::PathBuf;
18use tracing::{debug, warn};
19
20#[async_trait::async_trait]
21pub trait Engine {
22    fn name(&self) -> &'static str;
23    async fn check(&mut self, text: &str, language_id: &str) -> Result<Vec<Diagnostic>>;
24    /// BCP-47 primary subtags this engine supports. Empty = all languages.
25    fn supported_languages(&self) -> Vec<&'static str> {
26        vec![]
27    }
28}
29
30/// Returns `true` if `engine` supports the given BCP-47 `lang_tag`.
31///
32/// Matching is on the primary subtag: `"en-US"` matches an engine that
33/// advertises `"en"`. An engine with an empty list is a wildcard (supports all).
34pub fn engine_supports_language(engine: &(dyn Engine + Send), lang_tag: &str) -> bool {
35    let supported = engine.supported_languages();
36    if supported.is_empty() {
37        return true;
38    }
39    let primary = lang_tag.split('-').next().unwrap_or(lang_tag);
40    supported.iter().any(|s| s.eq_ignore_ascii_case(primary))
41}
42
43pub struct HarperEngine {
44    linter: LintGroup,
45    dict: Lrc<FstDictionary>,
46}
47
48impl HarperEngine {
49    #[must_use]
50    pub fn new(config: &crate::config::HarperConfig) -> Self {
51        let dialect = match config.dialect.as_str() {
52            "British" => Dialect::British,
53            "Canadian" => Dialect::Canadian,
54            "Australian" => Dialect::Australian,
55            _ => Dialect::American,
56        };
57        let dict = FstDictionary::curated();
58        let mut linter = LintGroup::new_curated(dict.clone(), dialect);
59
60        for (rule, enabled) in &config.linters {
61            linter.config.set_rule_enabled(rule, *enabled);
62        }
63
64        Self { linter, dict }
65    }
66}
67
68#[async_trait::async_trait]
69impl Engine for HarperEngine {
70    fn name(&self) -> &'static str {
71        "harper"
72    }
73
74    fn supported_languages(&self) -> Vec<&'static str> {
75        vec!["en"]
76    }
77
78    async fn check(&mut self, text: &str, _language_id: &str) -> Result<Vec<Diagnostic>> {
79        let document = Document::new(text, &Markdown::default(), self.dict.as_ref());
80        let lints = self.linter.lint(&document);
81
82        let diagnostics = lints
83            .into_iter()
84            .map(|lint| {
85                let suggestions = lint
86                    .suggestions
87                    .into_iter()
88                    .map(|s| match s {
89                        harper_core::linting::Suggestion::ReplaceWith(chars) => {
90                            chars.into_iter().collect::<String>()
91                        }
92                        harper_core::linting::Suggestion::InsertAfter(chars) => {
93                            let content: String = chars.into_iter().collect();
94                            format!("Insert \"{content}\"")
95                        }
96                        // Empty string replacement = delete the diagnostic range
97                        harper_core::linting::Suggestion::Remove => String::new(),
98                    })
99                    .collect();
100
101                Diagnostic {
102                    #[allow(clippy::cast_possible_truncation)]
103                    start_byte: lint.span.start as u32,
104                    #[allow(clippy::cast_possible_truncation)]
105                    end_byte: lint.span.end as u32,
106                    message: lint.message,
107                    suggestions,
108                    rule_id: format!("harper.{:?}", lint.lint_kind),
109                    severity: Severity::Warning as i32,
110                    unified_id: String::new(), // Will be filled by normalizer
111                    confidence: 0.8,
112                }
113            })
114            .collect();
115
116        Ok(diagnostics)
117    }
118}
119
120pub struct LanguageToolEngine {
121    url: String,
122    level: String,
123    mother_tongue: Option<String>,
124    disabled_rules: Vec<String>,
125    enabled_rules: Vec<String>,
126    disabled_categories: Vec<String>,
127    enabled_categories: Vec<String>,
128    client: reqwest::Client,
129}
130
131#[derive(Deserialize)]
132struct LTResponse {
133    matches: Vec<LTMatch>,
134}
135
136#[derive(Deserialize)]
137struct LTMatch {
138    message: String,
139    offset: usize,
140    length: usize,
141    replacements: Vec<LTReplacement>,
142    rule: LTRule,
143}
144
145#[derive(Deserialize)]
146struct LTReplacement {
147    value: String,
148}
149
150#[derive(Deserialize)]
151#[serde(rename_all = "camelCase")]
152struct LTRule {
153    id: String,
154    issue_type: String,
155}
156
157impl LanguageToolEngine {
158    #[must_use]
159    pub fn new(config: &crate::config::LanguageToolConfig) -> Self {
160        let client = reqwest::Client::builder()
161            .connect_timeout(std::time::Duration::from_secs(3))
162            .timeout(std::time::Duration::from_secs(10))
163            .build()
164            .unwrap_or_default();
165        Self {
166            url: config.url.clone(),
167            level: config.level.clone(),
168            mother_tongue: config.mother_tongue.clone(),
169            disabled_rules: config.disabled_rules.clone(),
170            enabled_rules: config.enabled_rules.clone(),
171            disabled_categories: config.disabled_categories.clone(),
172            enabled_categories: config.enabled_categories.clone(),
173            client,
174        }
175    }
176}
177
178#[allow(clippy::too_many_lines, clippy::cast_possible_truncation)]
179#[async_trait::async_trait]
180impl Engine for LanguageToolEngine {
181    fn name(&self) -> &'static str {
182        "languagetool"
183    }
184
185    async fn check(&mut self, text: &str, language_id: &str) -> Result<Vec<Diagnostic>> {
186        let url = format!("{}/v2/check", self.url);
187
188        // language_id is now a BCP-47 tag from the orchestrator (e.g. "en-US", "de-DE")
189        let lt_lang = language_id;
190
191        debug!(
192            url = %url,
193            language = lt_lang,
194            text_len = text.len(),
195            "LanguageTool request"
196        );
197
198        let mut form_params: Vec<(&str, String)> = vec![
199            ("text", text.to_string()),
200            ("language", lt_lang.to_string()),
201        ];
202        if self.level != "default" {
203            form_params.push(("level", self.level.clone()));
204        }
205        if let Some(ref mt) = self.mother_tongue {
206            form_params.push(("motherTongue", mt.clone()));
207        }
208        if !self.disabled_rules.is_empty() {
209            form_params.push(("disabledRules", self.disabled_rules.join(",")));
210        }
211        if !self.enabled_rules.is_empty() {
212            form_params.push(("enabledRules", self.enabled_rules.join(",")));
213        }
214        if !self.disabled_categories.is_empty() {
215            form_params.push(("disabledCategories", self.disabled_categories.join(",")));
216        }
217        if !self.enabled_categories.is_empty() {
218            form_params.push(("enabledCategories", self.enabled_categories.join(",")));
219        }
220
221        let request_start = std::time::Instant::now();
222        let response = match self.client.post(&url).form(&form_params).send().await {
223            Ok(r) => {
224                let status = r.status();
225                debug!(
226                    status = %status,
227                    elapsed_ms = request_start.elapsed().as_millis() as u64,
228                    "LanguageTool HTTP response"
229                );
230                if !status.is_success() {
231                    let body = r.text().await.unwrap_or_default();
232                    warn!(
233                        status = %status,
234                        body = %body,
235                        "LanguageTool returned non-200"
236                    );
237                    return Err(anyhow::anyhow!("LanguageTool HTTP {status}: {body}"));
238                }
239                r
240            }
241            Err(e) => {
242                warn!(
243                    elapsed_ms = request_start.elapsed().as_millis() as u64,
244                    "LanguageTool connection error: {e}"
245                );
246                return Err(anyhow::anyhow!("LanguageTool connection error: {e}"));
247            }
248        };
249
250        let res = match response.json::<LTResponse>().await {
251            Ok(r) => r,
252            Err(e) => {
253                warn!("LanguageTool JSON parse error: {e}");
254                return Err(anyhow::anyhow!("LanguageTool JSON parse error: {e}"));
255            }
256        };
257
258        debug!(
259            matches = res.matches.len(),
260            elapsed_ms = request_start.elapsed().as_millis() as u64,
261            "LanguageTool check complete"
262        );
263
264        let diagnostics = res
265            .matches
266            .into_iter()
267            .map(|m| {
268                let severity = match m.rule.issue_type.as_str() {
269                    "misspelling" => Severity::Error,
270                    "typographical" => Severity::Warning,
271                    _ => Severity::Information,
272                };
273
274                Diagnostic {
275                    #[allow(clippy::cast_possible_truncation)]
276                    start_byte: m.offset as u32,
277                    #[allow(clippy::cast_possible_truncation)]
278                    end_byte: (m.offset + m.length) as u32,
279                    message: m.message,
280                    suggestions: m.replacements.into_iter().map(|r| r.value).collect(),
281                    rule_id: format!("languagetool.{}", m.rule.id),
282                    severity: severity as i32,
283                    unified_id: String::new(), // Will be filled by normalizer
284                    confidence: 0.8,
285                }
286            })
287            .collect();
288
289        Ok(diagnostics)
290    }
291}
292
/// A checker engine that delegates to a subprocess, exchanging JSON over
/// the child's stdin and stdout.
pub struct ExternalEngine {
    /// Provider name; appears in log fields and in generated rule ids.
    name: String,
    /// Program spawned for every check.
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
}

impl ExternalEngine {
    /// Creates an engine that runs `command` with `args` and labels its
    /// diagnostics with `name`. Nothing is spawned until a check is requested.
    #[must_use]
    pub const fn new(name: String, command: String, args: Vec<String>) -> Self {
        Self { name, command, args }
    }
}
310
311/// JSON request sent to the external process on stdin.
312#[derive(serde::Serialize)]
313struct ExternalRequest<'a> {
314    text: &'a str,
315    language_id: &'a str,
316}
317
318/// JSON diagnostic returned by the external process on stdout.
319#[derive(Deserialize)]
320struct ExternalDiagnostic {
321    start_byte: u32,
322    end_byte: u32,
323    message: String,
324    #[serde(default)]
325    suggestions: Vec<String>,
326    #[serde(default)]
327    rule_id: String,
328    #[serde(default = "default_severity_value")]
329    severity: i32,
330    #[serde(default)]
331    confidence: f32,
332}
333
334const fn default_severity_value() -> i32 {
335    Severity::Warning as i32
336}
337
338#[async_trait::async_trait]
339impl Engine for ExternalEngine {
340    fn name(&self) -> &'static str {
341        "external"
342    }
343
344    async fn check(&mut self, text: &str, language_id: &str) -> Result<Vec<Diagnostic>> {
345        use tokio::process::Command;
346
347        let request = ExternalRequest { text, language_id };
348        let input = serde_json::to_string(&request)?;
349
350        let output = match Command::new(&self.command)
351            .args(&self.args)
352            .stdin(std::process::Stdio::piped())
353            .stdout(std::process::Stdio::piped())
354            .stderr(std::process::Stdio::piped())
355            .spawn()
356        {
357            Ok(mut child) => {
358                use tokio::io::AsyncWriteExt;
359                if let Some(mut stdin) = child.stdin.take() {
360                    // Ignore write errors — the process may exit before reading stdin.
361                    let _ = stdin.write_all(input.as_bytes()).await;
362                    let _ = stdin.shutdown().await;
363                }
364                child.wait_with_output().await?
365            }
366            Err(e) => {
367                warn!(provider = %self.name, "Failed to spawn external provider: {e}");
368                return Ok(vec![]);
369            }
370        };
371
372        if !output.status.success() {
373            let stderr = String::from_utf8_lossy(&output.stderr);
374            warn!(
375                provider = %self.name,
376                status = %output.status,
377                stderr = stderr.trim(),
378                "External provider exited with error"
379            );
380            return Ok(vec![]);
381        }
382
383        let stdout = String::from_utf8_lossy(&output.stdout);
384        let ext_diagnostics: Vec<ExternalDiagnostic> = match serde_json::from_str(&stdout) {
385            Ok(d) => d,
386            Err(e) => {
387                warn!(provider = %self.name, "Failed to parse external provider output: {e}");
388                return Ok(vec![]);
389            }
390        };
391
392        let diagnostics = ext_diagnostics
393            .into_iter()
394            .map(|ed| {
395                let rule_id = if ed.rule_id.is_empty() {
396                    format!("external.{}", self.name)
397                } else {
398                    format!("external.{}.{}", self.name, ed.rule_id)
399                };
400                Diagnostic {
401                    start_byte: ed.start_byte,
402                    end_byte: ed.end_byte,
403                    message: ed.message,
404                    suggestions: ed.suggestions,
405                    rule_id,
406                    severity: ed.severity,
407                    unified_id: String::new(),
408                    confidence: if ed.confidence > 0.0 {
409                        ed.confidence
410                    } else {
411                        0.7
412                    },
413                }
414            })
415            .collect();
416
417        Ok(diagnostics)
418    }
419}
420
421/// A WASM checker plugin loaded via Extism.
422///
423/// The plugin must export a `check` function that accepts a JSON string
424/// `{"text": "...", "language_id": "..."}` and returns a JSON array of
425/// diagnostics matching the `ExternalDiagnostic` schema.
426pub struct WasmEngine {
427    name: String,
428    plugin: Plugin,
429}
430
431// SAFETY: Extism Plugin is not Send by default because it wraps a wasmtime Store
432// which holds raw pointers. However, we only ever access the plugin from a single
433// &mut self call at a time (the Engine trait takes &mut self), so this is safe
434// as long as we don't share across threads simultaneously.
435unsafe impl Send for WasmEngine {}
436
437impl WasmEngine {
438    /// Create a new WASM engine from a `.wasm` file path.
439    pub fn new(name: String, wasm_path: PathBuf) -> Result<Self> {
440        let wasm = Wasm::file(wasm_path);
441        let manifest = Manifest::new([wasm]);
442        let plugin = Plugin::new(&manifest, [], true)?;
443        Ok(Self { name, plugin })
444    }
445
446    /// Create a new WASM engine from raw bytes (useful for testing).
447    pub fn from_bytes(name: String, wasm_bytes: &[u8]) -> Result<Self> {
448        let wasm = Wasm::data(wasm_bytes.to_vec());
449        let manifest = Manifest::new([wasm]);
450        let plugin = Plugin::new(&manifest, [], true)?;
451        Ok(Self { name, plugin })
452    }
453}
454
455#[async_trait::async_trait]
456impl Engine for WasmEngine {
457    fn name(&self) -> &'static str {
458        "wasm"
459    }
460
461    async fn check(&mut self, text: &str, language_id: &str) -> Result<Vec<Diagnostic>> {
462        let request = serde_json::json!({
463            "text": text,
464            "language_id": language_id,
465        });
466        let input = request.to_string();
467
468        let output = match self.plugin.call::<&str, &str>("check", &input) {
469            Ok(result) => result.to_string(),
470            Err(e) => {
471                warn!(plugin = %self.name, "WASM plugin call failed: {e}");
472                return Ok(vec![]);
473            }
474        };
475
476        let ext_diagnostics: Vec<ExternalDiagnostic> = match serde_json::from_str(&output) {
477            Ok(d) => d,
478            Err(e) => {
479                warn!(plugin = %self.name, "Failed to parse WASM plugin output: {e}");
480                return Ok(vec![]);
481            }
482        };
483
484        let diagnostics = ext_diagnostics
485            .into_iter()
486            .map(|ed| {
487                let rule_id = if ed.rule_id.is_empty() {
488                    format!("wasm.{}", self.name)
489                } else {
490                    format!("wasm.{}.{}", self.name, ed.rule_id)
491                };
492                Diagnostic {
493                    start_byte: ed.start_byte,
494                    end_byte: ed.end_byte,
495                    message: ed.message,
496                    suggestions: ed.suggestions,
497                    rule_id,
498                    severity: ed.severity,
499                    unified_id: String::new(),
500                    confidence: if ed.confidence > 0.0 {
501                        ed.confidence
502                    } else {
503                        0.7
504                    },
505                }
506            })
507            .collect();
508
509        Ok(diagnostics)
510    }
511}
512
/// Discover WASM plugins from a directory (e.g. `.languagecheck/plugins/`).
///
/// Returns a `(name, path)` pair for every regular file in `plugin_dir` whose
/// extension is `wasm`, where `name` is the file stem. Sub-directories are
/// never reported, even if their name ends in `.wasm`. The result is sorted
/// by name so that plugin order does not depend on the platform's directory
/// iteration order.
///
/// A missing or unreadable directory yields an empty list; entries that
/// cannot be read are skipped.
#[must_use]
pub fn discover_wasm_plugins(plugin_dir: &std::path::Path) -> Vec<(String, PathBuf)> {
    let Ok(entries) = std::fs::read_dir(plugin_dir) else {
        return Vec::new();
    };

    let mut plugins: Vec<(String, PathBuf)> = entries
        .filter_map(|entry| {
            let path = entry.ok()?.path();
            let has_wasm_ext = path.extension().is_some_and(|ext| ext == "wasm");
            // `is_file` follows symlinks, so a link to a plugin file still counts.
            if !has_wasm_ext || !path.is_file() {
                return None;
            }
            let name = path.file_stem()?.to_string_lossy().into_owned();
            Some((name, path))
        })
        .collect();

    plugins.sort();
    plugins
}
537
#[cfg(test)]
mod tests {
    use super::*;

    /// Harper flags the article error in a short English sentence.
    #[tokio::test]
    async fn test_harper_engine() -> Result<()> {
        let mut engine = HarperEngine::new(&crate::config::HarperConfig::default());
        let text = "This is an test.";
        let diagnostics = engine.check(text, "en-US").await?;

        // Harper should find "an test" error
        assert!(!diagnostics.is_empty());

        Ok(())
    }

    /// A provider that prints a valid JSON array is parsed, and its rule id is
    /// namespaced with the provider name.
    #[tokio::test]
    async fn external_engine_with_echo() -> Result<()> {
        // Use a simple shell command that echoes a valid JSON response
        let mut engine = ExternalEngine::new(
            "test-provider".to_string(),
            "sh".to_string(),
            vec![
                "-c".to_string(),
                r#"cat > /dev/null; echo '[{"start_byte":0,"end_byte":4,"message":"test issue","suggestions":["fix"],"rule_id":"test.rule","severity":2}]'"#.to_string(),
            ],
        );

        let diagnostics = engine.check("some text", "en-US").await?;
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(diagnostics[0].message, "test issue");
        assert_eq!(diagnostics[0].rule_id, "external.test-provider.test.rule");
        assert_eq!(diagnostics[0].suggestions, vec!["fix"]);
        assert_eq!(diagnostics[0].start_byte, 0);
        assert_eq!(diagnostics[0].end_byte, 4);

        Ok(())
    }

    /// A provider binary that does not exist is not an error.
    #[tokio::test]
    async fn external_engine_missing_binary() -> Result<()> {
        let mut engine = ExternalEngine::new(
            "nonexistent".to_string(),
            "/nonexistent/binary".to_string(),
            vec![],
        );

        // Should not error, just return empty
        let diagnostics = engine.check("text", "en-US").await?;
        assert!(diagnostics.is_empty());

        Ok(())
    }

    /// Output that is not valid JSON is not an error.
    #[tokio::test]
    async fn external_engine_bad_json_output() -> Result<()> {
        let mut engine = ExternalEngine::new(
            "bad-json".to_string(),
            "echo".to_string(),
            vec!["not json".to_string()],
        );

        // Should not error, just return empty
        let diagnostics = engine.check("text", "en-US").await?;
        assert!(diagnostics.is_empty());

        Ok(())
    }

    #[test]
    fn wasm_engine_invalid_bytes_returns_error() {
        let result = WasmEngine::from_bytes("bad-plugin".to_string(), b"not a wasm file");
        assert!(result.is_err());
    }

    #[test]
    fn wasm_engine_missing_file_returns_error() {
        let result = WasmEngine::new(
            "missing".to_string(),
            PathBuf::from("/nonexistent/plugin.wasm"),
        );
        assert!(result.is_err());
    }

    #[test]
    fn discover_wasm_plugins_empty_dir() {
        let dir = std::env::temp_dir().join("lang_check_test_wasm_empty");
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();

        let plugins = discover_wasm_plugins(&dir);
        assert!(plugins.is_empty());

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn discover_wasm_plugins_finds_wasm_files() {
        let dir = std::env::temp_dir().join("lang_check_test_wasm_discover");
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();

        // Create fake .wasm files and a non-wasm file
        std::fs::write(dir.join("checker.wasm"), b"fake").unwrap();
        std::fs::write(dir.join("linter.wasm"), b"fake").unwrap();
        std::fs::write(dir.join("readme.txt"), b"not a plugin").unwrap();

        // Sort locally so the assertions do not rely on discovery order.
        let mut plugins = discover_wasm_plugins(&dir);
        plugins.sort_by(|a, b| a.0.cmp(&b.0));

        assert_eq!(plugins.len(), 2);
        assert_eq!(plugins[0].0, "checker");
        assert_eq!(plugins[1].0, "linter");
        assert!(plugins[0].1.ends_with("checker.wasm"));
        assert!(plugins[1].1.ends_with("linter.wasm"));

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn discover_wasm_plugins_nonexistent_dir() {
        let plugins = discover_wasm_plugins(std::path::Path::new("/nonexistent/dir"));
        assert!(plugins.is_empty());
    }

    /// Live integration test — requires LT Docker on localhost:8010.
    /// Run with: `cargo test lt_engine_live -- --ignored --nocapture`
    #[tokio::test]
    #[ignore = "requires a LanguageTool server on localhost:8010"]
    async fn lt_engine_live() -> Result<()> {
        // Initialize tracing for visible output
        let _ = tracing_subscriber::fmt()
            .with_env_filter("debug")
            .with_writer(std::io::stderr)
            .with_target(false)
            .try_init();

        let mut engine = LanguageToolEngine::new(&crate::config::LanguageToolConfig::default());
        let text = "This is a sentnce with erors.";
        // The engine forwards `language_id` verbatim as LanguageTool's
        // `language` parameter, so it must be a BCP-47 tag, not a file type.
        let diagnostics = engine.check(text, "en-US").await?;

        println!("LT returned {} diagnostics:", diagnostics.len());
        for d in &diagnostics {
            println!(
                "  [{}-{}] {} (rule: {}, suggestions: {:?})",
                d.start_byte, d.end_byte, d.message, d.rule_id, d.suggestions
            );
        }

        assert!(
            diagnostics.len() >= 2,
            "Expected at least 2 spelling errors, got {}",
            diagnostics.len()
        );
        Ok(())
    }

    /// The response structs accept LanguageTool's camelCase keys and ignore
    /// fields they do not declare.
    #[test]
    fn lt_response_deserializes_camel_case() {
        // Real LanguageTool API response (trimmed) — uses camelCase `issueType`
        let json = r#"{
            "matches": [{
                "message": "Possible spelling mistake found.",
                "offset": 10,
                "length": 7,
                "replacements": [{"value": "sentence"}],
                "rule": {
                    "id": "MORFOLOGIK_RULE_EN_US",
                    "description": "Possible spelling mistake",
                    "issueType": "misspelling",
                    "category": {"id": "TYPOS", "name": "Possible Typo"}
                }
            }]
        }"#;
        let res: LTResponse = serde_json::from_str(json).unwrap();
        assert_eq!(res.matches.len(), 1);
        assert_eq!(res.matches[0].rule.id, "MORFOLOGIK_RULE_EN_US");
        assert_eq!(res.matches[0].rule.issue_type, "misspelling");
        assert_eq!(res.matches[0].offset, 10);
        assert_eq!(res.matches[0].length, 7);
        assert_eq!(res.matches[0].replacements[0].value, "sentence");
    }
}