use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::LanguageId;
/// Fallback compressor service URL used by `CompressorConfig::default` when the
/// `REDCOMPRESSOR_URL` environment variable is not set.
// NOTE(review): hard-coded private-network address; deployments outside that network
// should set `REDCOMPRESSOR_URL` rather than rely on this default.
pub const DEFAULT_COMPRESSOR_URL: &str = "http://10.166.1.220:8787";
/// Inputs shorter than this many bytes are never sent to the service
/// (`CompressorClient::compress_code` returns `None` for them).
const MIN_SNIPPET_BYTES: usize = 8;
/// Settings for reaching the compressor service.
#[derive(Debug, Clone)]
pub struct CompressorConfig {
/// Service base URL (e.g. `http://host:port`); endpoint paths such as `/v1/compress`
/// are appended to it by `CompressorClient`.
pub base_url: String,
/// Whether compression is switched on.
// NOTE(review): not read by `CompressorClient::from_config`; presumably callers check
// this flag before building a client — verify.
pub enabled: bool,
}
impl Default for CompressorConfig {
    /// Builds a config from the environment with compression enabled.
    ///
    /// The base URL comes from `REDCOMPRESSOR_URL`, trimmed of surrounding whitespace.
    /// It falls back to [`DEFAULT_COMPRESSOR_URL`] when the variable is unset, is not
    /// valid Unicode, or is blank. Without the blank check, an empty value would
    /// produce relative request URLs such as `/v1/compress` that can never succeed.
    fn default() -> Self {
        let base_url = std::env::var("REDCOMPRESSOR_URL")
            .ok()
            .map(|url| url.trim().to_string())
            .filter(|url| !url.is_empty())
            .unwrap_or_else(|| DEFAULT_COMPRESSOR_URL.to_string());
        Self {
            base_url,
            enabled: true,
        }
    }
}
/// Errors returned by `CompressorClient` operations.
#[derive(Debug, Error)]
pub enum CompressError {
/// Building the HTTP client, sending a request, or reading/parsing a response
/// body failed.
#[error("HTTP request failed: {0}")]
Http(#[from] reqwest::Error),
/// The service answered with a non-success HTTP status; `message` describes the
/// failure. `decompress_code` also reports an empty input blob this way (status 400)
/// without contacting the service.
#[error("compress API returned {status}: {message}")]
Api { status: u16, message: String },
/// The `blob_b64` field of a compress response was not valid standard base64.
#[error("invalid base64 in compress response: {0}")]
BadBase64(#[from] base64::DecodeError),
}
/// JSON body POSTed to `/v1/compress`.
#[derive(Debug, Serialize)]
struct CompressRequest<'a> {
/// Source text to compress.
code: &'a str,
/// Compressor language name (as produced by `compressor_language_name`).
language: &'a str,
}
/// Success body returned by `/v1/compress`.
#[derive(Debug, Deserialize)]
struct CompressResponse {
/// Compressed bytes, encoded with standard base64.
blob_b64: String,
}
/// JSON body POSTed to `/v1/decompress`.
#[derive(Debug, Serialize)]
struct DecompressRequest<'a> {
/// Compressed bytes, encoded with standard base64.
blob_b64: &'a str,
/// Compressor language name the blob was compressed with.
language: &'a str,
}
/// Success body returned by `/v1/decompress`.
// NOTE(review): `Serialize` appears to be needed only by the round-trip unit test —
// confirm nothing else serializes this type.
#[derive(Debug, Serialize, Deserialize)]
struct DecompressResponse {
/// Restored source text.
code: String,
}
/// JSON error body the client tries to parse from non-success responses.
/// Either field may be absent; a body that does not parse is used as raw text instead.
#[derive(Debug, Deserialize)]
struct ApiErrorBody {
error: Option<String>,
message: Option<String>,
}
/// HTTP client for the compressor service.
///
/// The two `logged_*` flags rate-limit stderr warnings: once a flag is set, further
/// warnings of that kind are suppressed for the lifetime of this client.
pub struct CompressorClient {
client: Client,
/// Service base URL with any trailing `/` removed.
base_url: String,
/// Set once the first transport-level failure has been reported.
logged_network_failure: AtomicBool,
/// Set once the first `dict_missing` API error has been reported.
logged_dict_missing: AtomicBool,
}
impl CompressorClient {
pub fn new(base_url: impl Into<String>) -> Result<Self, CompressError> {
let client = Client::builder()
.timeout(Duration::from_secs(30))
.build()?;
Ok(Self {
client,
base_url: base_url.into().trim_end_matches('/').to_string(),
logged_network_failure: AtomicBool::new(false),
logged_dict_missing: AtomicBool::new(false),
})
}
pub fn from_config(config: &CompressorConfig) -> Result<Self, CompressError> {
Self::new(config.base_url.clone())
}
pub async fn health_check(&self) -> Result<(), CompressError> {
let url = format!("{}/healthz", self.base_url);
let resp = self.client.get(&url).send().await?;
if resp.status().is_success() {
Ok(())
} else {
Err(CompressError::Api {
status: resp.status().as_u16(),
message: format!("healthz returned {}", resp.status()),
})
}
}
pub async fn compress_code(
&self,
code: &str,
language: LanguageId,
) -> Option<Vec<u8>> {
if code.len() < MIN_SNIPPET_BYTES {
return None;
}
let Some(lang) = compressor_language_name(language) else {
return None;
};
match self.compress_code_raw(code, lang).await {
Ok(blob) => Some(blob),
Err(CompressError::Api { status, message }) => {
if status == 422 && message.contains("dict_missing") {
if !self.logged_dict_missing.swap(true, Ordering::Relaxed) {
eprintln!(
"RedCompressor: dict_missing for language `{lang}` — skipping code_bytes (further warnings suppressed)"
);
}
} else if status == 400 && message.contains("unknown_language") {
eprintln!("RedCompressor: unknown language `{lang}`");
} else {
eprintln!("RedCompressor: API error {status}: {message}");
}
None
}
Err(e) => {
if !self.logged_network_failure.swap(true, Ordering::Relaxed) {
eprintln!("RedCompressor: request failed ({e}) — skipping code_bytes (further warnings suppressed)");
}
None
}
}
}
async fn compress_code_raw(&self, code: &str, language: &str) -> Result<Vec<u8>, CompressError> {
let url = format!("{}/v1/compress", self.base_url);
let body = CompressRequest { code, language };
let resp = self.client.post(&url).json(&body).send().await?;
let status = resp.status();
if status.is_success() {
let parsed: CompressResponse = resp.json().await?;
return B64.decode(parsed.blob_b64).map_err(CompressError::BadBase64);
}
let text = resp.text().await.unwrap_or_default();
let message = serde_json::from_str::<ApiErrorBody>(&text)
.ok()
.and_then(|e| e.message.or(e.error))
.unwrap_or(text);
Err(CompressError::Api {
status: status.as_u16(),
message,
})
}
pub async fn decompress_code(
&self,
blob: &[u8],
language: &str,
) -> Result<String, CompressError> {
if blob.is_empty() {
return Err(CompressError::Api {
status: 400,
message: "empty blob".into(),
});
}
let url = format!("{}/v1/decompress", self.base_url);
let blob_b64 = B64.encode(blob);
let body = DecompressRequest {
blob_b64: &blob_b64,
language,
};
let resp = self.client.post(&url).json(&body).send().await?;
let status = resp.status();
if status.is_success() {
let parsed: DecompressResponse = resp.json().await?;
return Ok(parsed.code);
}
let text = resp.text().await.unwrap_or_default();
let message = serde_json::from_str::<ApiErrorBody>(&text)
.ok()
.and_then(|e| e.message.or(e.error))
.unwrap_or(text);
Err(CompressError::Api {
status: status.as_u16(),
message,
})
}
}
/// Maps a parser [`LanguageId`] to the language name the compressor API expects.
///
/// Every variant currently has a name; TSX shares `"typescript"` with TypeScript. The
/// match has no wildcard arm, so adding a `LanguageId` variant forces a decision here.
pub fn compressor_language_name(language: LanguageId) -> Option<&'static str> {
    let api_name = match language {
        LanguageId::Java => "java",
        LanguageId::JavaScript => "javascript",
        LanguageId::TypeScript | LanguageId::Tsx => "typescript",
        LanguageId::Python => "python",
        LanguageId::Rust => "rust",
        LanguageId::Go => "go",
        LanguageId::Erlang => "erlang",
        LanguageId::CSharp => "csharp",
    };
    Some(api_name)
}
/// Compresses the byte range `span = (lo, hi)` of `source`.
///
/// Both offsets are clamped to `source.len()`. Returns `None` when:
/// - `span` is `None`;
/// - the clamped range is empty or inverted;
/// - either offset falls inside a multi-byte UTF-8 character;
/// - [`CompressorClient::compress_code`] declines or fails.
pub async fn compress_snippet(
    source: &str,
    span: Option<(usize, usize)>,
    language: LanguageId,
    client: &CompressorClient,
) -> Option<Vec<u8>> {
    let (lo, hi) = span?;
    let lo = lo.min(source.len());
    let hi = hi.min(source.len());
    if lo >= hi {
        return None;
    }
    // `str::get` yields `None` instead of panicking (as `&source[lo..hi]` would) when an
    // offset is not on a char boundary, e.g. a span computed against a different
    // revision of the text.
    let snippet = source.get(lo..hi)?;
    client.compress_code(snippet, language).await
}
/// Best-effort decompression of `blob` as source code in `language`.
///
/// Yields `None` if `language` has no compressor name or if
/// [`CompressorClient::decompress_code`] fails for any reason; the error itself is
/// discarded.
pub async fn decompress_code_bytes(
    blob: &[u8],
    language: LanguageId,
    client: &CompressorClient,
) -> Option<String> {
    match compressor_language_name(language) {
        Some(api_name) => client.decompress_code(blob, api_name).await.ok(),
        None => None,
    }
}
/// Parses a language name as it appears in the IR into a [`LanguageId`].
///
/// Matching is case-insensitive. Full names and common aliases are accepted (`js`,
/// `ts`, `py`, `rs`, `golang`, `erl`, `cs`, `c_sharp`). The input is not trimmed, and
/// anything unrecognized yields `None`.
pub fn language_id_from_ir_string(s: &str) -> Option<LanguageId> {
    let normalized = s.to_lowercase();
    let id = match normalized.as_str() {
        "java" => LanguageId::Java,
        "js" | "javascript" => LanguageId::JavaScript,
        "ts" | "typescript" => LanguageId::TypeScript,
        "tsx" => LanguageId::Tsx,
        "py" | "python" => LanguageId::Python,
        "rs" | "rust" => LanguageId::Rust,
        "golang" | "go" => LanguageId::Go,
        "erl" | "erlang" => LanguageId::Erlang,
        "cs" | "c_sharp" | "csharp" => LanguageId::CSharp,
        _ => return None,
    };
    Some(id)
}
pub fn compressor_language_from_ir_string(s: &str) -> Option<&'static str> {
language_id_from_ir_string(s).and_then(compressor_language_name)
}
/// Compresses an entire source file.
///
/// An empty `source` short-circuits to `None`. Everything else is delegated to
/// [`CompressorClient::compress_code`], which applies its own size and language checks.
pub async fn compress_full_source(
    source: &str,
    language: LanguageId,
    client: &CompressorClient,
) -> Option<Vec<u8>> {
    if source.is_empty() {
        None
    } else {
        client.compress_code(source, language).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn maps_all_parser_languages_including_erlang() {
        assert_eq!(compressor_language_name(LanguageId::Erlang), Some("erlang"));
        assert_eq!(compressor_language_name(LanguageId::CSharp), Some("csharp"));
        assert_eq!(compressor_language_name(LanguageId::Tsx), Some("typescript"));
        assert_eq!(compressor_language_name(LanguageId::Java), Some("java"));
    }

    #[test]
    fn compress_request_json_matches_script_format() {
        let req = CompressRequest {
            code: "def f():\n return 1\n",
            language: "python",
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains(r#""language":"python""#));
        // Newlines must be JSON-escaped (`\n`), not emitted raw.
        assert!(json.contains(r#"\n"#));
    }

    /// Parses a real JSON body (exercising `CompressResponse`'s `Deserialize`) and
    /// checks that the blob decodes back to the original bytes.
    #[test]
    fn decodes_compress_response_blob() {
        let sample = b"hello world".to_vec();
        let b64 = B64.encode(&sample);
        let json = format!(r#"{{"blob_b64":"{b64}"}}"#);
        let resp: CompressResponse = serde_json::from_str(&json).unwrap();
        let decoded = B64.decode(resp.blob_b64).unwrap();
        assert_eq!(decoded, sample);
    }

    #[test]
    fn decompress_request_json_matches_format() {
        let blob_b64 = B64.encode(b"compressed");
        let req = DecompressRequest {
            blob_b64: &blob_b64,
            language: "rust",
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains(r#""language":"rust""#));
        assert!(json.contains("blob_b64"));
    }

    #[test]
    fn parses_decompress_response_code() {
        let resp = DecompressResponse {
            code: "fn main() {}".into(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let parsed: DecompressResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.code, "fn main() {}");
    }

    #[test]
    fn language_id_from_ir_string_maps_csharp_variants() {
        assert_eq!(
            language_id_from_ir_string("c_sharp"),
            Some(LanguageId::CSharp)
        );
        assert_eq!(language_id_from_ir_string("rust"), Some(LanguageId::Rust));
    }

    #[test]
    fn language_id_from_ir_string_is_case_insensitive_and_rejects_unknown() {
        assert_eq!(language_id_from_ir_string("Rust"), Some(LanguageId::Rust));
        assert_eq!(language_id_from_ir_string("TSX"), Some(LanguageId::Tsx));
        assert_eq!(language_id_from_ir_string("cobol"), None);
    }

    #[test]
    fn compressor_language_from_ir_string_resolves_aliases() {
        assert_eq!(compressor_language_from_ir_string("tsx"), Some("typescript"));
        assert_eq!(compressor_language_from_ir_string("cs"), Some("csharp"));
        assert_eq!(compressor_language_from_ir_string("unknown"), None);
    }

    /// Error bodies may carry only one of the two fields; the other must default to `None`.
    #[test]
    fn api_error_body_accepts_partial_fields() {
        let body: ApiErrorBody = serde_json::from_str(r#"{"error":"dict_missing"}"#).unwrap();
        assert_eq!(body.error.as_deref(), Some("dict_missing"));
        assert!(body.message.is_none());
    }
}