1use std::sync::atomic::{AtomicBool, Ordering};
2use std::time::Duration;
3
4use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
5use reqwest::Client;
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9use crate::LanguageId;
10
/// Base URL of the compressor service used when the `REDCOMPRESSOR_URL`
/// environment variable cannot be read (see `CompressorConfig::default`).
pub const DEFAULT_COMPRESSOR_URL: &str = "http://10.166.1.220:8787";

/// Inputs shorter than this many bytes are never sent to the service;
/// `CompressorClient::compress_code` returns `None` for them.
const MIN_SNIPPET_BYTES: usize = 8;
13
/// Settings describing how to reach the compressor service.
#[derive(Debug, Clone)]
pub struct CompressorConfig {
    /// Root URL of the service, e.g. `http://host:port`.
    pub base_url: String,
    /// Compression on/off switch.
    ///
    /// NOTE(review): nothing in this module reads this flag; presumably
    /// callers gate on it before building a client — confirm.
    pub enabled: bool,
}
19
20impl Default for CompressorConfig {
21 fn default() -> Self {
22 Self {
23 base_url: std::env::var("REDCOMPRESSOR_URL")
24 .unwrap_or_else(|_| DEFAULT_COMPRESSOR_URL.to_string()),
25 enabled: true,
26 }
27 }
28}
29
30#[derive(Debug, Error)]
31pub enum CompressError {
32 #[error("HTTP request failed: {0}")]
33 Http(#[from] reqwest::Error),
34 #[error("compress API returned {status}: {message}")]
35 Api { status: u16, message: String },
36 #[error("invalid base64 in compress response: {0}")]
37 BadBase64(#[from] base64::DecodeError),
38}
39
40#[derive(Debug, Serialize)]
41struct CompressRequest<'a> {
42 code: &'a str,
43 language: &'a str,
44}
45
46#[derive(Debug, Deserialize)]
47struct CompressResponse {
48 blob_b64: String,
49}
50
51#[derive(Debug, Serialize)]
52struct DecompressRequest<'a> {
53 blob_b64: &'a str,
54 language: &'a str,
55}
56
57#[derive(Debug, Serialize, Deserialize)]
58struct DecompressResponse {
59 code: String,
60}
61
62#[derive(Debug, Deserialize)]
63struct ApiErrorBody {
64 error: Option<String>,
65 message: Option<String>,
66}
67
68pub struct CompressorClient {
70 client: Client,
71 base_url: String,
72 logged_network_failure: AtomicBool,
73 logged_dict_missing: AtomicBool,
74}
75
76impl CompressorClient {
77 pub fn new(base_url: impl Into<String>) -> Result<Self, CompressError> {
78 let client = Client::builder()
79 .timeout(Duration::from_secs(30))
80 .build()?;
81 Ok(Self {
82 client,
83 base_url: base_url.into().trim_end_matches('/').to_string(),
84 logged_network_failure: AtomicBool::new(false),
85 logged_dict_missing: AtomicBool::new(false),
86 })
87 }
88
89 pub fn from_config(config: &CompressorConfig) -> Result<Self, CompressError> {
90 Self::new(config.base_url.clone())
91 }
92
93 pub async fn health_check(&self) -> Result<(), CompressError> {
94 let url = format!("{}/healthz", self.base_url);
95 let resp = self.client.get(&url).send().await?;
96 if resp.status().is_success() {
97 Ok(())
98 } else {
99 Err(CompressError::Api {
100 status: resp.status().as_u16(),
101 message: format!("healthz returned {}", resp.status()),
102 })
103 }
104 }
105
106 pub async fn compress_code(
107 &self,
108 code: &str,
109 language: LanguageId,
110 ) -> Option<Vec<u8>> {
111 if code.len() < MIN_SNIPPET_BYTES {
112 return None;
113 }
114 let Some(lang) = compressor_language_name(language) else {
115 return None;
116 };
117
118 match self.compress_code_raw(code, lang).await {
119 Ok(blob) => Some(blob),
120 Err(CompressError::Api { status, message }) => {
121 if status == 422 && message.contains("dict_missing") {
122 if !self.logged_dict_missing.swap(true, Ordering::Relaxed) {
123 eprintln!(
124 "RedCompressor: dict_missing for language `{lang}` — skipping code_bytes (further warnings suppressed)"
125 );
126 }
127 } else if status == 400 && message.contains("unknown_language") {
128 eprintln!("RedCompressor: unknown language `{lang}`");
129 } else {
130 eprintln!("RedCompressor: API error {status}: {message}");
131 }
132 None
133 }
134 Err(e) => {
135 if !self.logged_network_failure.swap(true, Ordering::Relaxed) {
136 eprintln!("RedCompressor: request failed ({e}) — skipping code_bytes (further warnings suppressed)");
137 }
138 None
139 }
140 }
141 }
142
143 async fn compress_code_raw(&self, code: &str, language: &str) -> Result<Vec<u8>, CompressError> {
144 let url = format!("{}/v1/compress", self.base_url);
145 let body = CompressRequest { code, language };
146 let resp = self.client.post(&url).json(&body).send().await?;
147
148 let status = resp.status();
149 if status.is_success() {
150 let parsed: CompressResponse = resp.json().await?;
151 return B64.decode(parsed.blob_b64).map_err(CompressError::BadBase64);
152 }
153
154 let text = resp.text().await.unwrap_or_default();
155 let message = serde_json::from_str::<ApiErrorBody>(&text)
156 .ok()
157 .and_then(|e| e.message.or(e.error))
158 .unwrap_or(text);
159 Err(CompressError::Api {
160 status: status.as_u16(),
161 message,
162 })
163 }
164
165 pub async fn decompress_code(
167 &self,
168 blob: &[u8],
169 language: &str,
170 ) -> Result<String, CompressError> {
171 if blob.is_empty() {
172 return Err(CompressError::Api {
173 status: 400,
174 message: "empty blob".into(),
175 });
176 }
177 let url = format!("{}/v1/decompress", self.base_url);
178 let blob_b64 = B64.encode(blob);
179 let body = DecompressRequest {
180 blob_b64: &blob_b64,
181 language,
182 };
183 let resp = self.client.post(&url).json(&body).send().await?;
184
185 let status = resp.status();
186 if status.is_success() {
187 let parsed: DecompressResponse = resp.json().await?;
188 return Ok(parsed.code);
189 }
190
191 let text = resp.text().await.unwrap_or_default();
192 let message = serde_json::from_str::<ApiErrorBody>(&text)
193 .ok()
194 .and_then(|e| e.message.or(e.error))
195 .unwrap_or(text);
196 Err(CompressError::Api {
197 status: status.as_u16(),
198 message,
199 })
200 }
201}
202
203pub fn compressor_language_name(language: LanguageId) -> Option<&'static str> {
205 match language {
206 LanguageId::Java => Some("java"),
207 LanguageId::JavaScript => Some("javascript"),
208 LanguageId::TypeScript | LanguageId::Tsx => Some("typescript"),
209 LanguageId::Python => Some("python"),
210 LanguageId::Rust => Some("rust"),
211 LanguageId::Go => Some("go"),
212 LanguageId::Erlang => Some("erlang"),
213 LanguageId::CSharp => Some("csharp"),
214 }
215}
216
217pub async fn compress_snippet(
219 source: &str,
220 span: Option<(usize, usize)>,
221 language: LanguageId,
222 client: &CompressorClient,
223) -> Option<Vec<u8>> {
224 let (lo, hi) = span?;
225 let lo = lo.min(source.len());
226 let hi = hi.min(source.len());
227 if lo >= hi {
228 return None;
229 }
230 client.compress_code(&source[lo..hi], language).await
231}
232
233pub async fn decompress_code_bytes(
235 blob: &[u8],
236 language: LanguageId,
237 client: &CompressorClient,
238) -> Option<String> {
239 let lang = compressor_language_name(language)?;
240 client.decompress_code(blob, lang).await.ok()
241}
242
243pub fn language_id_from_ir_string(s: &str) -> Option<LanguageId> {
245 match s.to_lowercase().as_str() {
246 "java" => Some(LanguageId::Java),
247 "javascript" | "js" => Some(LanguageId::JavaScript),
248 "typescript" | "ts" => Some(LanguageId::TypeScript),
249 "tsx" => Some(LanguageId::Tsx),
250 "python" | "py" => Some(LanguageId::Python),
251 "rust" | "rs" => Some(LanguageId::Rust),
252 "go" | "golang" => Some(LanguageId::Go),
253 "erlang" | "erl" => Some(LanguageId::Erlang),
254 "csharp" | "c_sharp" | "cs" => Some(LanguageId::CSharp),
255 _ => None,
256 }
257}
258
259pub fn compressor_language_from_ir_string(s: &str) -> Option<&'static str> {
261 language_id_from_ir_string(s).and_then(compressor_language_name)
262}
263
264pub async fn compress_full_source(
266 source: &str,
267 language: LanguageId,
268 client: &CompressorClient,
269) -> Option<Vec<u8>> {
270 if source.is_empty() {
271 return None;
272 }
273 client.compress_code(source, language).await
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `LanguageId` variant handled by `compressor_language_name`
    /// maps to the expected service-side name.
    #[test]
    fn maps_all_parser_languages_including_erlang() {
        assert_eq!(compressor_language_name(LanguageId::Java), Some("java"));
        assert_eq!(
            compressor_language_name(LanguageId::JavaScript),
            Some("javascript")
        );
        assert_eq!(
            compressor_language_name(LanguageId::TypeScript),
            Some("typescript")
        );
        assert_eq!(compressor_language_name(LanguageId::Tsx), Some("typescript"));
        assert_eq!(compressor_language_name(LanguageId::Python), Some("python"));
        assert_eq!(compressor_language_name(LanguageId::Rust), Some("rust"));
        assert_eq!(compressor_language_name(LanguageId::Go), Some("go"));
        assert_eq!(compressor_language_name(LanguageId::Erlang), Some("erlang"));
        assert_eq!(compressor_language_name(LanguageId::CSharp), Some("csharp"));
    }

    /// The compress request serialises the language verbatim and escapes
    /// newlines in the code.
    #[test]
    fn compress_request_json_matches_script_format() {
        let req = CompressRequest {
            code: "def f():\n return 1\n",
            language: "python",
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains(r#""language":"python""#));
        assert!(json.contains(r#"\n"#));
    }

    /// A base64 blob carried in `CompressResponse` decodes back to the
    /// original bytes.
    #[test]
    fn decodes_compress_response_blob() {
        let sample = b"hello world".to_vec();
        let b64 = B64.encode(&sample);
        let resp = CompressResponse { blob_b64: b64 };
        let decoded = B64.decode(resp.blob_b64).unwrap();
        assert_eq!(decoded, sample);
    }

    /// The decompress request carries both the language and a `blob_b64`
    /// field.
    #[test]
    fn decompress_request_json_matches_format() {
        let blob_b64 = B64.encode(b"compressed");
        let req = DecompressRequest {
            blob_b64: &blob_b64,
            language: "rust",
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains(r#""language":"rust""#));
        assert!(json.contains("blob_b64"));
    }

    /// `DecompressResponse` survives a JSON round trip.
    #[test]
    fn parses_decompress_response_code() {
        let resp = DecompressResponse {
            code: "fn main() {}".into(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let parsed: DecompressResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.code, "fn main() {}");
    }

    #[test]
    fn language_id_from_ir_string_maps_csharp_variants() {
        assert_eq!(
            language_id_from_ir_string("c_sharp"),
            Some(LanguageId::CSharp)
        );
        assert_eq!(language_id_from_ir_string("rust"), Some(LanguageId::Rust));
    }

    /// Short aliases resolve, matching ignores ASCII case, and unknown
    /// names are rejected.
    #[test]
    fn language_id_from_ir_string_handles_aliases_case_and_unknown() {
        assert_eq!(language_id_from_ir_string("js"), Some(LanguageId::JavaScript));
        assert_eq!(language_id_from_ir_string("golang"), Some(LanguageId::Go));
        assert_eq!(language_id_from_ir_string("ERL"), Some(LanguageId::Erlang));
        assert_eq!(language_id_from_ir_string("C_Sharp"), Some(LanguageId::CSharp));
        assert_eq!(language_id_from_ir_string("cobol"), None);
        assert_eq!(language_id_from_ir_string(""), None);
    }

    /// The combined helper goes from an IR string to the service name.
    #[test]
    fn compressor_language_from_ir_string_resolves_service_name() {
        assert_eq!(compressor_language_from_ir_string("tsx"), Some("typescript"));
        assert_eq!(compressor_language_from_ir_string("py"), Some("python"));
        assert_eq!(compressor_language_from_ir_string("cobol"), None);
    }
}