Skip to main content

lang_check/engines/
proselint.rs

1use crate::checker::{Diagnostic, Severity};
2use anyhow::Result;
3use serde::Deserialize;
4use std::collections::HashMap;
5use tracing::{debug, warn};
6
7use super::Engine;
8
/// Checker engine backed by the external `proselint` command-line tool.
///
/// The engine spawns `proselint check -o json`, feeds it the text on stdin,
/// and converts the reported diagnostics into this crate's `Diagnostic` type.
#[derive(Debug)]
pub struct ProselintEngine {
    /// Optional path handed to proselint via `--config`; when `None`, the
    /// flag is omitted.
    config_path: Option<String>,
}

impl ProselintEngine {
    /// Creates an engine that will pass `config_path` (if any) to proselint
    /// through its `--config` option.
    #[must_use]
    pub const fn new(config_path: Option<String>) -> Self {
        Self { config_path }
    }
}
19
/// Top-level JSON output from `proselint check -o json`.
///
/// Only the `result` object is read.
#[derive(Deserialize)]
struct ProselintOutput {
    /// Per-input results, keyed by the name proselint reports for the input
    /// (`"<stdin>"` in this file's test fixtures).
    result: HashMap<String, ProselintFileResult>,
}
25
/// Per-file result — either diagnostics or an error.
///
/// The enum is untagged, so the variant is chosen by which key is present in
/// the JSON object: `diagnostics` selects `Ok`, `error` selects `Err`.
/// Variants are tried in declaration order.
#[derive(Deserialize)]
#[serde(untagged)]
enum ProselintFileResult {
    /// The input was checked; `diagnostics` may be empty.
    Ok {
        diagnostics: Vec<ProselintDiagnostic>,
    },
    /// Proselint could not check the input and reported an error instead.
    Err {
        error: ProselintError,
    },
}
37
/// A single finding reported by proselint.
///
/// JSON keys not listed here (for example `pos` in the test fixtures) are
/// ignored during deserialization.
#[derive(Deserialize)]
struct ProselintDiagnostic {
    /// Name of the check that fired; used to build the `proselint.<check_path>`
    /// rule id.
    check_path: String,
    /// Human-readable description of the finding.
    message: String,
    /// Character offsets [start, end] in padded content (shifted by +1).
    span: (usize, usize),
    /// Suggested replacement text, or null.
    replacements: Option<String>,
}
47
/// Error object proselint attaches to an input it failed to check.
///
/// Only `message` is read; other keys (such as the numeric `code` seen in the
/// test fixtures) are ignored.
#[derive(Deserialize)]
struct ProselintError {
    /// Error description; logged when a file-level error is reported.
    message: String,
}
52
/// Convert proselint's character-offset span (1-based due to `"\n"` padding)
/// to byte offsets in the original text.
///
/// Proselint internally pads content as `"\n" + content + "\n"`, so all span
/// values are shifted by +1 character. We subtract 1 to get the offset into
/// the original text, then convert from char offset to byte offset.
///
/// The returned `(start, end)` pair always lies on char boundaries of `text`
/// and always satisfies `start <= end`:
///
/// * offsets past the end of `text` are clamped to `text.len()`;
/// * a span value of `0` (which would point into the padding) is treated as
///   the start of the text;
/// * a malformed span whose end precedes its start collapses to an empty
///   range at the start offset instead of producing an inverted range.
// Byte offsets are stored as `u32` by the diagnostic type; inputs are assumed
// to be far smaller than 4 GiB, so the narrowing casts below are accepted.
#[allow(clippy::cast_possible_truncation)]
fn char_span_to_byte_range(text: &str, span: (usize, usize)) -> (u32, u32) {
    // Subtract the 1-char padding offset; never let the end precede the start.
    let char_start = span.0.saturating_sub(1);
    let char_end = span.1.saturating_sub(1).max(char_start);

    // Byte offset of every char boundary in order, including the final
    // end-of-text boundary so that "one past the last char" resolves too.
    let mut boundaries = text
        .char_indices()
        .map(|(byte_idx, _)| byte_idx)
        .chain(std::iter::once(text.len()));

    // Single forward pass: advance to the start boundary, then continue from
    // there to the end boundary. Running off the end clamps to `text.len()`.
    let byte_start = boundaries.nth(char_start).unwrap_or(text.len());
    let byte_end = match char_end - char_start {
        0 => byte_start,
        ahead => boundaries.nth(ahead - 1).unwrap_or(text.len()),
    };

    (byte_start as u32, byte_end as u32)
}
80
81#[async_trait::async_trait]
82impl Engine for ProselintEngine {
83    fn name(&self) -> &'static str {
84        "proselint"
85    }
86
87    fn supported_languages(&self) -> Vec<&'static str> {
88        vec!["en"]
89    }
90
91    async fn check(&mut self, text: &str, _language_id: &str) -> Result<Vec<Diagnostic>> {
92        use tokio::io::AsyncWriteExt;
93        use tokio::process::Command;
94
95        let mut cmd = Command::new("proselint");
96        cmd.arg("check").arg("-o").arg("json");
97
98        if let Some(cfg) = &self.config_path {
99            cmd.arg("--config").arg(cfg);
100        }
101
102        cmd.stdin(std::process::Stdio::piped())
103            .stdout(std::process::Stdio::piped())
104            .stderr(std::process::Stdio::piped());
105
106        let output = match cmd.spawn() {
107            Ok(mut child) => {
108                if let Some(mut stdin) = child.stdin.take() {
109                    let _ = stdin.write_all(text.as_bytes()).await;
110                    let _ = stdin.shutdown().await;
111                }
112                child.wait_with_output().await?
113            }
114            Err(e) => {
115                warn!("Failed to spawn proselint: {e}");
116                return Ok(vec![]);
117            }
118        };
119
120        // Exit code 0 = clean, 1 = found errors (both normal)
121        // Exit code >= 2 = actual error
122        let code = output.status.code().unwrap_or(4);
123        if code >= 2 {
124            let stderr = String::from_utf8_lossy(&output.stderr);
125            warn!(code, stderr = stderr.trim(), "Proselint error");
126            return Ok(vec![]);
127        }
128
129        let stdout = String::from_utf8_lossy(&output.stdout);
130        if stdout.trim().is_empty() {
131            return Ok(vec![]);
132        }
133
134        // Proselint may output JSON twice (a known quirk); use a streaming
135        // deserializer to parse only the first valid JSON object.
136        let mut de = serde_json::Deserializer::from_str(&stdout).into_iter::<ProselintOutput>();
137        let parsed: ProselintOutput = match de.next() {
138            Some(Ok(o)) => o,
139            Some(Err(e)) => {
140                warn!("Failed to parse proselint JSON: {e}");
141                debug!(stdout = %stdout, "Raw proselint output");
142                return Ok(vec![]);
143            }
144            None => return Ok(vec![]),
145        };
146
147        let mut diagnostics = Vec::new();
148        for file_result in parsed.result.into_values() {
149            match file_result {
150                ProselintFileResult::Ok { diagnostics: diags } => {
151                    for d in diags {
152                        let (start_byte, end_byte) = char_span_to_byte_range(text, d.span);
153                        let suggestions = d.replacements.map(|r| vec![r]).unwrap_or_default();
154
155                        diagnostics.push(Diagnostic {
156                            start_byte,
157                            end_byte,
158                            message: d.message,
159                            suggestions,
160                            rule_id: format!("proselint.{}", d.check_path),
161                            severity: Severity::Warning as i32,
162                            unified_id: String::new(),
163                            confidence: 0.7,
164                        });
165                    }
166                }
167                ProselintFileResult::Err { error } => {
168                    warn!(msg = error.message, "Proselint reported a file error");
169                }
170            }
171        }
172
173        Ok(diagnostics)
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// ASCII text: char offsets and byte offsets coincide once the padding
    /// shift is removed.
    #[test]
    fn char_span_basic() {
        let text = "Hello world";
        // "world" is reported as (7, 12) because of the leading pad char.
        let (lo, hi) = char_span_to_byte_range(text, (7, 12));
        assert_eq!((lo, hi), (6, 11));
        assert_eq!(&text[lo as usize..hi as usize], "world");
    }

    /// A span covering the whole text starts at padded offset 1.
    #[test]
    fn char_span_start_of_text() {
        let text = "Hello";
        let (lo, hi) = char_span_to_byte_range(text, (1, 6));
        assert_eq!((lo, hi), (0, 5));
        assert_eq!(&text[lo as usize..hi as usize], "Hello");
    }

    /// Multi-byte chars before the span must shift the byte offsets.
    #[test]
    fn char_span_unicode() {
        let text = "café latte";
        // "latte" begins at char index 5, i.e. padded span (6, 11).
        let (lo, hi) = char_span_to_byte_range(text, (6, 11));
        assert_eq!(&text[lo as usize..hi as usize], "latte");
    }

    /// An end offset beyond the text is clamped to the text length.
    #[test]
    fn char_span_clamped() {
        let text = "short";
        let (lo, hi) = char_span_to_byte_range(text, (1, 100));
        assert_eq!(lo, 0);
        assert_eq!(hi as usize, text.len());
    }

    /// A diagnostic with a string replacement parses; unknown keys are ignored.
    #[test]
    fn proselint_diagnostic_deserializes() {
        let payload = r#"{
            "check_path": "uncomparables",
            "message": "Comparison of an uncomparable: 'very unique'.",
            "span": [10, 21],
            "replacements": "unique",
            "pos": [1, 9]
        }"#;
        let diag: ProselintDiagnostic = serde_json::from_str(payload).unwrap();
        assert_eq!(diag.check_path, "uncomparables");
        assert_eq!(diag.span, (10, 21));
        assert_eq!(diag.replacements.as_deref(), Some("unique"));
    }

    /// A JSON `null` replacement maps to `None`.
    #[test]
    fn proselint_diagnostic_null_replacements() {
        let payload = r#"{
            "check_path": "hedging",
            "message": "Hedging: 'I think'.",
            "span": [1, 8],
            "replacements": null,
            "pos": [1, 0]
        }"#;
        let diag: ProselintDiagnostic = serde_json::from_str(payload).unwrap();
        assert!(diag.replacements.is_none());
    }

    /// The complete envelope with one diagnostic selects the `Ok` variant.
    #[test]
    fn proselint_full_output_deserializes() {
        let payload = r#"{
            "result": {
                "<stdin>": {
                    "diagnostics": [
                        {
                            "check_path": "uncomparables",
                            "message": "Comparison of an uncomparable.",
                            "span": [10, 21],
                            "replacements": "unique",
                            "pos": [1, 9]
                        }
                    ]
                }
            }
        }"#;
        let parsed: ProselintOutput = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.result.len(), 1);
        if let ProselintFileResult::Ok { diagnostics } = &parsed.result["<stdin>"] {
            assert_eq!(diagnostics.len(), 1);
            assert_eq!(diagnostics[0].check_path, "uncomparables");
        } else {
            panic!("expected Ok");
        }
    }

    /// An `error` object selects the `Err` variant and exposes its message.
    #[test]
    fn proselint_error_result_deserializes() {
        let payload = r#"{
            "result": {
                "<stdin>": {
                    "error": {
                        "code": -31997,
                        "message": "Some error occurred"
                    }
                }
            }
        }"#;
        let parsed: ProselintOutput = serde_json::from_str(payload).unwrap();
        if let ProselintFileResult::Err { error } = &parsed.result["<stdin>"] {
            assert_eq!(error.message, "Some error occurred");
        } else {
            panic!("expected Err");
        }
    }

    /// `check` must return `Ok` whether or not the `proselint` binary exists.
    #[tokio::test]
    async fn proselint_engine_missing_binary() -> Result<()> {
        let mut engine = ProselintEngine::new(None);
        assert!(engine.check("test text", "en-US").await.is_ok());
        Ok(())
    }

    /// Live integration test — requires `proselint` installed.
    /// Run with: `cargo test proselint_engine_live -- --ignored --nocapture`
    #[tokio::test]
    #[ignore]
    async fn proselint_engine_live() -> Result<()> {
        let sample = "This is very unique and extremely obvious.";
        let mut engine = ProselintEngine::new(None);
        let diagnostics = engine.check(sample, "en-US").await?;

        println!("Proselint returned {} diagnostics:", diagnostics.len());
        for found in &diagnostics {
            println!(
                "  [{}-{}] {} (rule: {}, suggestions: {:?})",
                found.start_byte, found.end_byte, found.message, found.rule_id, found.suggestions
            );
        }

        assert!(
            !diagnostics.is_empty(),
            "Expected at least 1 diagnostic from proselint"
        );
        assert!(diagnostics[0].rule_id.starts_with("proselint."));
        Ok(())
    }
}