/// Physicochemical measurements describing one wine, i.e. the model's input features.
///
/// There are no `rename` attributes, so each field name is also its JSON key in request
/// bodies and in the serialized example request. The fields mirror the columns of the UCI
/// wine-quality dataset — NOTE(review): units presumably follow that dataset; confirm.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct WineFeatures {
pub fixed_acidity: f32,
pub volatile_acidity: f32,
pub citric_acid: f32,
pub residual_sugar: f32,
pub chlorides: f32,
pub free_sulfur_dioxide: f32,
pub total_sulfur_dioxide: f32,
pub density: f32,
pub ph: f32,
pub sulphates: f32,
pub alcohol: f32,
}
impl WineFeatures {
    /// Rescales every measurement to the feature vector consumed by the linear model.
    ///
    /// Each feature becomes `(value - offset) / span`, except `citric_acid`, which is passed
    /// through unchanged. The output order is: fixed acidity, volatile acidity, citric acid,
    /// residual sugar, chlorides, free SO2, total SO2, density, pH, sulphates, alcohol.
    /// The offset/span pairs look like approximate dataset minimums and ranges, so typical
    /// inputs land near `[0, 1]` — NOTE(review): confirm against the training data.
    fn normalize(&self) -> Vec<f32> {
        let scale = |value: f32, offset: f32, span: f32| (value - offset) / span;

        let mut features = Vec::with_capacity(11);
        features.push(scale(self.fixed_acidity, 4.0, 12.0));
        features.push(scale(self.volatile_acidity, 0.1, 1.5));
        features.push(self.citric_acid);
        features.push(scale(self.residual_sugar, 0.9, 14.1));
        features.push(scale(self.chlorides, 0.01, 0.59));
        features.push(scale(self.free_sulfur_dioxide, 1.0, 71.0));
        features.push(scale(self.total_sulfur_dioxide, 6.0, 283.0));
        features.push(scale(self.density, 0.99, 0.05));
        features.push(scale(self.ph, 2.7, 1.3));
        features.push(scale(self.sulphates, 0.3, 1.7));
        features.push(scale(self.alcohol, 8.0, 7.0));
        features
    }
}
/// Prediction request body.
///
/// Because of `#[serde(flatten)]`, the wine features are read from the top level of the JSON
/// object (e.g. `{"fixed_acidity": 7.4, ...}`), not from a nested `"wine"` key.
#[derive(Debug, serde::Deserialize)]
pub struct Request {
#[serde(flatten)]
pub wine: WineFeatures,
}
/// The subset of a Lambda Function URL invocation event that this handler reads.
///
/// `rename_all = "camelCase"` maps `is_base64_encoded` to the JSON key `isBase64Encoded`.
/// The struct does not use `deny_unknown_fields`, so every other event key is ignored.
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionUrlEvent {
/// Raw request body; `None` when the payload has no `body` key.
pub body: Option<String>,
/// Whether `body` is base64-encoded; defaults to `false` when the key is missing.
#[serde(default)]
pub is_base64_encoded: bool,
}
/// Prediction result returned to the caller as JSON (field names are used verbatim as keys).
#[derive(Debug, serde::Serialize)]
pub struct Response {
/// Model score clamped to the range `[0, 10]`.
pub quality: f32,
/// Human-readable quality band derived from `quality` ("Poor", "Average", "Good", "Excellent").
pub category: String,
/// Fraction (0..=1) of normalized input features that fall inside `[0, 1]`.
pub confidence: f32,
/// Names of the features with the largest absolute contribution to the score.
pub top_factors: Vec<String>,
}
/// Linear model: `score = bias + sum(weights[i] * normalized[i])`.
struct WinePredictor {
/// One weight per normalized feature, in the order produced by `WineFeatures::normalize`.
weights: Vec<f32>,
/// Constant term the weighted sum starts from.
bias: f32,
}
impl WinePredictor {
    /// Builds the predictor with its fixed, hand-tuned linear weights.
    ///
    /// The weights are listed in the same order as the vector returned by
    /// `WineFeatures::normalize`.
    fn new() -> Self {
        Self {
            weights: vec![
                -0.05, // fixed_acidity
                -0.85, // volatile_acidity
                0.45,  // citric_acid
                0.02,  // residual_sugar
                -0.15, // chlorides
                0.08,  // free_sulfur_dioxide
                -0.12, // total_sulfur_dioxide
                -0.30, // density
                -0.10, // ph
                0.55,  // sulphates
                0.95,  // alcohol
            ],
            bias: 5.5,
        }
    }

    /// Scores a wine with the linear model and explains the result.
    ///
    /// * `quality`: `bias` plus the weighted sum of normalized features, clamped to
    ///   `[0, 10]`. It is NaN when the sum is undefined (e.g. opposing infinite
    ///   contributions from extreme inputs).
    /// * `category`: "Excellent" (>= 8), "Good" (>= 7), "Average" (>= 5), otherwise "Poor".
    ///   A NaN quality is reported as "Poor".
    /// * `confidence`: fraction of normalized features lying inside `[0, 1]`.
    /// * `top_factors`: the three features with the largest absolute contribution
    ///   `|weight * value|`, largest first; ties keep feature order.
    fn predict(&self, features: &WineFeatures) -> Response {
        const FEATURE_NAMES: [&str; 11] = [
            "fixed_acidity",
            "volatile_acidity",
            "citric_acid",
            "residual_sugar",
            "chlorides",
            "free_sulfur_dioxide",
            "total_sulfur_dioxide",
            "density",
            "ph",
            "sulphates",
            "alcohol",
        ];

        let normalized = features.normalize();

        // Accumulate left to right starting from the bias, matching the original summation order.
        let score = self
            .weights
            .iter()
            .zip(&normalized)
            .fold(self.bias, |acc, (w, x)| acc + w * x);
        let quality = score.clamp(0.0, 10.0);

        // Test from the top band down: a NaN quality fails every comparison and therefore
        // lands in "Poor" instead of silently being labelled "Excellent".
        let label = if quality >= 8.0 {
            "Excellent"
        } else if quality >= 7.0 {
            "Good"
        } else if quality >= 5.0 {
            "Average"
        } else {
            "Poor"
        };
        let category = label.to_string();

        let confidence = normalized
            .iter()
            .filter(|&&x| (0.0..=1.0).contains(&x))
            .count() as f32
            / normalized.len() as f32;

        let mut contributions: Vec<(&str, f32)> = FEATURE_NAMES
            .iter()
            .zip(self.weights.iter().zip(&normalized))
            .map(|(&name, (w, x))| (name, (w * x).abs()))
            .collect();
        // `total_cmp` is a total order, so sorting stays well-defined even when a contribution
        // is NaN (it sorts first, as the largest value). The stable sort keeps feature order
        // for ties.
        contributions.sort_by(|a, b| b.1.total_cmp(&a.1));
        let top_factors: Vec<String> = contributions
            .iter()
            .take(3)
            .map(|&(name, _)| name.to_string())
            .collect();

        Response {
            quality,
            category,
            confidence,
            top_factors,
        }
    }
}
/// Process-wide model instance, built on first use and then shared by every invocation.
static PREDICTOR: std::sync::OnceLock<WinePredictor> = std::sync::OnceLock::new();

/// Returns the shared predictor, running `WinePredictor::new` only on the first call.
fn get_predictor() -> &'static WinePredictor {
    let build = WinePredictor::new;
    PREDICTOR.get_or_init(build)
}
fn parse_and_predict(raw_body: &str, predictor: &WinePredictor) -> String {
if let Ok(event) = serde_json::from_str::<FunctionUrlEvent>(raw_body) {
if let Some(body) = event.body {
let decoded = if event.is_base64_encoded {
let bytes = base64_decode(&body).unwrap_or_default();
String::from_utf8(bytes).unwrap_or_default()
} else {
body
};
return match serde_json::from_str::<Request>(&decoded) {
Ok(req) => {
let response = predictor.predict(&req.wine);
serde_json::to_string(&response)
.unwrap_or_else(|_| r#"{"error":"serialization failed"}"#.to_string())
},
Err(e) => format!(r#"{{"error":"Invalid request body: {e}"}}"#),
};
}
}
match serde_json::from_str::<Request>(raw_body) {
Ok(req) => {
let response = predictor.predict(&req.wine);
serde_json::to_string(&response)
.unwrap_or_else(|_| r#"{"error":"serialization failed"}"#.to_string())
},
Err(e) => format!(r#"{{"error":"Invalid request: {e}"}}"#),
}
}
/// Decodes standard-alphabet (`A-Z a-z 0-9 + /`) base64 text.
///
/// Decoding stops at the first `=`, and anything after it is ignored. Returns `None` if any
/// byte before that point is outside the alphabet (including whitespace). Trailing bits that
/// do not complete a byte are dropped.
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    /// Maps one base64 symbol to its 6-bit value.
    fn sextet(symbol: u8) -> Option<u32> {
        let value = match symbol {
            b'A'..=b'Z' => symbol - b'A',
            b'a'..=b'z' => symbol - b'a' + 26,
            b'0'..=b'9' => symbol - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        };
        Some(u32::from(value))
    }

    let mut decoded = Vec::new();
    // `acc` holds bits not yet emitted; `pending` counts them (always < 8 between symbols).
    let mut acc: u32 = 0;
    let mut pending: u32 = 0;
    for symbol in input.bytes().take_while(|&b| b != b'=') {
        acc = (acc << 6) | sextet(symbol)?;
        pending += 6;
        if pending >= 8 {
            pending -= 8;
            decoded.push((acc >> pending) as u8);
            acc &= (1 << pending) - 1;
        }
    }
    Some(decoded)
}
/// Chooses the run mode: the Lambda runtime loop when `AWS_LAMBDA_RUNTIME_API` is readable,
/// otherwise the local console demo.
fn main() {
    if std::env::var("AWS_LAMBDA_RUNTIME_API").is_err() {
        local_demo();
        return;
    }
    lambda_runtime();
}
fn local_demo() {
println!("Wine Quality Predictor (Local Demo)\n");
let predictor = get_predictor();
let wines = vec![
(
"Premium Bordeaux",
WineFeatures {
fixed_acidity: 7.4,
volatile_acidity: 0.28,
citric_acid: 0.45,
residual_sugar: 2.1,
chlorides: 0.076,
free_sulfur_dioxide: 15.0,
total_sulfur_dioxide: 46.0,
density: 0.9958,
ph: 3.35,
sulphates: 0.68,
alcohol: 12.8,
},
),
(
"Budget Table Wine",
WineFeatures {
fixed_acidity: 8.5,
volatile_acidity: 0.72,
citric_acid: 0.12,
residual_sugar: 3.8,
chlorides: 0.092,
free_sulfur_dioxide: 8.0,
total_sulfur_dioxide: 28.0,
density: 0.9972,
ph: 3.42,
sulphates: 0.48,
alcohol: 10.2,
},
),
];
for (name, features) in wines {
let response = predictor.predict(&features);
println!("{name}:");
println!(
" Quality: {:.2}/10 ({})",
response.quality, response.category
);
println!(" Confidence: {:.0}%", response.confidence * 100.0);
println!(" Top factors: {}\n", response.top_factors.join(", "));
}
println!("Example JSON request:");
let example = WineFeatures {
fixed_acidity: 7.0,
volatile_acidity: 0.3,
citric_acid: 0.4,
residual_sugar: 2.0,
chlorides: 0.08,
free_sulfur_dioxide: 15.0,
total_sulfur_dioxide: 40.0,
density: 0.995,
ph: 3.3,
sulphates: 0.6,
alcohol: 12.0,
};
println!(
"{}",
serde_json::to_string_pretty(&example).unwrap_or_default()
);
}
/// Runs the Lambda custom-runtime event loop; it never returns.
///
/// Each iteration fetches the next event from the Runtime API's `/invocation/next`
/// endpoint and passes its body through `parse_and_predict`. The resulting JSON is posted
/// to `/invocation/{request_id}/response`. Transport errors are logged to stderr and the
/// loop continues.
///
/// # Panics
/// Panics if `AWS_LAMBDA_RUNTIME_API` is unset or not valid Unicode (`main` checks this
/// before calling).
fn lambda_runtime() {
    // Pause before retrying a failed fetch so a persistent outage does not become a busy
    // loop that floods the logs.
    const RETRY_DELAY: std::time::Duration = std::time::Duration::from_millis(100);

    let runtime_api =
        std::env::var("AWS_LAMBDA_RUNTIME_API").expect("AWS_LAMBDA_RUNTIME_API not set");
    let client = ureq::agent();
    let predictor = get_predictor();
    // The "next invocation" URL never changes, so build it once.
    let next_url = format!("http://{runtime_api}/2018-06-01/runtime/invocation/next");

    loop {
        let resp = match client.get(&next_url).call() {
            Ok(r) => r,
            Err(e) => {
                eprintln!("Failed to get invocation: {e}");
                std::thread::sleep(RETRY_DELAY);
                continue;
            }
        };

        // Without a request id there is no URL to report the result to.
        let Some(request_id) = resp
            .header("Lambda-Runtime-Aws-Request-Id")
            .map(str::to_owned)
        else {
            eprintln!("Invocation is missing the Lambda-Runtime-Aws-Request-Id header; skipping");
            continue;
        };

        // Best effort: an unreadable body is treated as empty, which `parse_and_predict`
        // reports back to the caller as an error response.
        let body = match resp.into_string() {
            Ok(text) => text,
            Err(e) => {
                eprintln!("Failed to read invocation body: {e}");
                String::new()
            }
        };
        let result = parse_and_predict(&body, predictor);

        let response_url =
            format!("http://{runtime_api}/2018-06-01/runtime/invocation/{request_id}/response");
        if let Err(e) = client
            .post(&response_url)
            .set("Content-Type", "application/json")
            .send_string(&result)
        {
            eprintln!("Failed to send response: {e}");
        }
    }
}