1use crate::checker::{Diagnostic, Severity};
2use anyhow::Result;
3use serde::Deserialize;
4use std::collections::HashMap;
5use tracing::{debug, warn};
6
7use super::Engine;
8
/// Checking engine that delegates to the external `proselint` executable.
pub struct ProselintEngine {
    /// Value passed to proselint after `--config`; when `None` the flag is
    /// not added to the command line at all.
    config_path: Option<String>,
}

impl ProselintEngine {
    /// Builds an engine.
    ///
    /// `config_path` is stored as-is and only consulted when a check is run;
    /// nothing is validated or opened here.
    #[must_use]
    pub const fn new(config_path: Option<String>) -> Self {
        Self { config_path }
    }
}
19
20#[derive(Deserialize)]
22struct ProselintOutput {
23 result: HashMap<String, ProselintFileResult>,
24}
25
26#[derive(Deserialize)]
28#[serde(untagged)]
29enum ProselintFileResult {
30 Ok {
31 diagnostics: Vec<ProselintDiagnostic>,
32 },
33 Err {
34 error: ProselintError,
35 },
36}
37
38#[derive(Deserialize)]
39struct ProselintDiagnostic {
40 check_path: String,
41 message: String,
42 span: (usize, usize),
44 replacements: Option<String>,
46}
47
48#[derive(Deserialize)]
49struct ProselintError {
50 message: String,
51}
52
/// Converts a proselint character span into a byte range within `text`.
///
/// `span` is read as a pair of 1-based character (Unicode scalar value)
/// positions `(start, end)`. One is subtracted from each (saturating at zero)
/// to obtain a 0-based, end-exclusive character range, which is then mapped to
/// UTF-8 byte offsets so the result can index into `text` directly.
///
/// Edge cases:
/// - A position past the last character maps to `text.len()`.
/// - An inverted span (`end < start`) yields an empty range located at the
///   start position, so the returned pair always satisfies `start <= end`.
///
/// Returns `(byte_start, byte_end)`.
// Offsets are reported as `u32`; inputs large enough for the cast to truncate
// (4 GiB or more of text) are not expected for prose checking.
#[allow(clippy::cast_possible_truncation)]
fn char_span_to_byte_range(text: &str, span: (usize, usize)) -> (u32, u32) {
    // Byte offset at which the character with 0-based index `char_idx` begins,
    // or the length of `text` when there is no such character.
    let byte_offset_of = |char_idx: usize| {
        text.char_indices()
            .nth(char_idx)
            .map_or(text.len(), |(byte_idx, _)| byte_idx)
    };

    let char_start = span.0.saturating_sub(1);
    // Never let the end precede the start: an inverted span collapses to an
    // empty range instead of producing `start > end`.
    let char_end = span.1.saturating_sub(1).max(char_start);

    let byte_start = byte_offset_of(char_start);
    let byte_end = if char_end == char_start {
        byte_start
    } else {
        byte_offset_of(char_end)
    };

    (byte_start as u32, byte_end as u32)
}
80
81#[async_trait::async_trait]
82impl Engine for ProselintEngine {
83 fn name(&self) -> &'static str {
84 "proselint"
85 }
86
87 fn supported_languages(&self) -> Vec<&'static str> {
88 vec!["en"]
89 }
90
91 async fn check(&mut self, text: &str, _language_id: &str) -> Result<Vec<Diagnostic>> {
92 use tokio::io::AsyncWriteExt;
93 use tokio::process::Command;
94
95 let mut cmd = Command::new("proselint");
96 cmd.arg("check").arg("-o").arg("json");
97
98 if let Some(cfg) = &self.config_path {
99 cmd.arg("--config").arg(cfg);
100 }
101
102 cmd.stdin(std::process::Stdio::piped())
103 .stdout(std::process::Stdio::piped())
104 .stderr(std::process::Stdio::piped());
105
106 let output = match cmd.spawn() {
107 Ok(mut child) => {
108 if let Some(mut stdin) = child.stdin.take() {
109 let _ = stdin.write_all(text.as_bytes()).await;
110 let _ = stdin.shutdown().await;
111 }
112 child.wait_with_output().await?
113 }
114 Err(e) => {
115 warn!("Failed to spawn proselint: {e}");
116 return Ok(vec![]);
117 }
118 };
119
120 let code = output.status.code().unwrap_or(4);
123 if code >= 2 {
124 let stderr = String::from_utf8_lossy(&output.stderr);
125 warn!(code, stderr = stderr.trim(), "Proselint error");
126 return Ok(vec![]);
127 }
128
129 let stdout = String::from_utf8_lossy(&output.stdout);
130 if stdout.trim().is_empty() {
131 return Ok(vec![]);
132 }
133
134 let mut de = serde_json::Deserializer::from_str(&stdout).into_iter::<ProselintOutput>();
137 let parsed: ProselintOutput = match de.next() {
138 Some(Ok(o)) => o,
139 Some(Err(e)) => {
140 warn!("Failed to parse proselint JSON: {e}");
141 debug!(stdout = %stdout, "Raw proselint output");
142 return Ok(vec![]);
143 }
144 None => return Ok(vec![]),
145 };
146
147 let mut diagnostics = Vec::new();
148 for file_result in parsed.result.into_values() {
149 match file_result {
150 ProselintFileResult::Ok { diagnostics: diags } => {
151 for d in diags {
152 let (start_byte, end_byte) = char_span_to_byte_range(text, d.span);
153 let suggestions = d.replacements.map(|r| vec![r]).unwrap_or_default();
154
155 diagnostics.push(Diagnostic {
156 start_byte,
157 end_byte,
158 message: d.message,
159 suggestions,
160 rule_id: format!("proselint.{}", d.check_path),
161 severity: Severity::Warning as i32,
162 unified_id: String::new(),
163 confidence: 0.7,
164 });
165 }
166 }
167 ProselintFileResult::Err { error } => {
168 warn!(msg = error.message, "Proselint reported a file error");
169 }
170 }
171 }
172
173 Ok(diagnostics)
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// A span in the middle of ASCII text maps to the matching byte range.
    #[test]
    fn char_span_basic() {
        let text = "Hello world";
        let (start, end) = char_span_to_byte_range(text, (7, 12));
        assert_eq!((start, end), (6, 11));
        assert_eq!(&text[start as usize..end as usize], "world");
    }

    /// A span beginning at position 1 starts at byte 0.
    #[test]
    fn char_span_start_of_text() {
        let text = "Hello";
        let (start, end) = char_span_to_byte_range(text, (1, 6));
        assert_eq!((start, end), (0, 5));
        assert_eq!(&text[start as usize..end as usize], "Hello");
    }

    /// Multi-byte characters before the span shift byte offsets accordingly.
    #[test]
    fn char_span_unicode() {
        let text = "café latte";
        let (start, end) = char_span_to_byte_range(text, (6, 11));
        assert_eq!(&text[start as usize..end as usize], "latte");
    }

    /// An end position beyond the text is clamped to the text length.
    #[test]
    fn char_span_clamped() {
        let text = "short";
        let (start, end) = char_span_to_byte_range(text, (1, 100));
        assert_eq!(start, 0);
        assert_eq!(end as usize, text.len());
    }

    /// A diagnostic with a string replacement deserializes; `pos` is ignored.
    #[test]
    fn proselint_diagnostic_deserializes() {
        let json = r#"{
            "check_path": "uncomparables",
            "message": "Comparison of an uncomparable: 'very unique'.",
            "span": [10, 21],
            "replacements": "unique",
            "pos": [1, 9]
        }"#;
        let diagnostic: ProselintDiagnostic = serde_json::from_str(json).unwrap();
        assert_eq!(diagnostic.check_path, "uncomparables");
        assert_eq!(diagnostic.span, (10, 21));
        assert_eq!(diagnostic.replacements.as_deref(), Some("unique"));
    }

    /// A JSON `null` replacement becomes `None`.
    #[test]
    fn proselint_diagnostic_null_replacements() {
        let json = r#"{
            "check_path": "hedging",
            "message": "Hedging: 'I think'.",
            "span": [1, 8],
            "replacements": null,
            "pos": [1, 0]
        }"#;
        let diagnostic: ProselintDiagnostic = serde_json::from_str(json).unwrap();
        assert!(diagnostic.replacements.is_none());
    }

    /// A complete document with one diagnostic selects the `Ok` variant.
    #[test]
    fn proselint_full_output_deserializes() {
        let json = r#"{
            "result": {
                "<stdin>": {
                    "diagnostics": [
                        {
                            "check_path": "uncomparables",
                            "message": "Comparison of an uncomparable.",
                            "span": [10, 21],
                            "replacements": "unique",
                            "pos": [1, 9]
                        }
                    ]
                }
            }
        }"#;
        let output: ProselintOutput = serde_json::from_str(json).unwrap();
        assert_eq!(output.result.len(), 1);

        if let ProselintFileResult::Ok { diagnostics } = &output.result["<stdin>"] {
            assert_eq!(diagnostics.len(), 1);
            assert_eq!(diagnostics[0].check_path, "uncomparables");
        } else {
            panic!("expected Ok");
        }
    }

    /// A per-input `error` object selects the `Err` variant.
    #[test]
    fn proselint_error_result_deserializes() {
        let json = r#"{
            "result": {
                "<stdin>": {
                    "error": {
                        "code": -31997,
                        "message": "Some error occurred"
                    }
                }
            }
        }"#;
        let output: ProselintOutput = serde_json::from_str(json).unwrap();

        if let ProselintFileResult::Err { error } = &output.result["<stdin>"] {
            assert_eq!(error.message, "Some error occurred");
        } else {
            panic!("expected Err");
        }
    }

    /// `check` must come back `Ok`; on a machine without proselint this
    /// exercises the spawn-failure path.
    #[tokio::test]
    async fn proselint_engine_missing_binary() -> Result<()> {
        let mut engine = ProselintEngine::new(None);
        let outcome = engine.check("test text", "en-US").await;
        assert!(outcome.is_ok());
        Ok(())
    }

    /// End-to-end run against the `proselint` command; ignored by default
    /// because it expects real diagnostics from an installed proselint.
    #[tokio::test]
    #[ignore]
    async fn proselint_engine_live() -> Result<()> {
        let mut engine = ProselintEngine::new(None);
        let text = "This is very unique and extremely obvious.";
        let diagnostics = engine.check(text, "en-US").await?;

        println!("Proselint returned {} diagnostics:", diagnostics.len());
        for found in &diagnostics {
            println!(
                " [{}-{}] {} (rule: {}, suggestions: {:?})",
                found.start_byte, found.end_byte, found.message, found.rule_id, found.suggestions
            );
        }

        assert!(
            !diagnostics.is_empty(),
            "Expected at least 1 diagnostic from proselint"
        );
        assert!(diagnostics[0].rule_id.starts_with("proselint."));
        Ok(())
    }
}