use axum::{
extract::{Query, State},
http::StatusCode,
Json,
};
use serde::{Deserialize, Serialize};
use crate::sharding::shard_for;
use crate::state::AppState;
/// Query-string parameters accepted by [`get_shard_info`].
#[derive(Debug, Deserialize)]
pub struct ShardInfoQuery {
/// Execution id as text; the handler parses it as an `i64` and answers
/// 400 when that parse fails.
// NOTE(review): presumably kept as a string so large ids survive clients
// without 64-bit integers — confirm with API consumers.
pub execution_id: String,
/// Number of shards to map the execution id onto. The handler rejects `0`
/// and any value above 1024.
pub shard_count: u32,
}
/// Shard settings of the answering server, as embedded in
/// [`ShardInfoResponse`]. The handler fills both fields from the `shard`
/// member of the application state.
#[derive(Debug, Serialize)]
pub struct ServerShardConfig {
/// Shard index taken from the application state.
pub shard_index: u32,
/// Shard count taken from the application state (independent of the
/// `shard_count` supplied in the query).
pub shard_count: u32,
}
/// JSON body returned by [`get_shard_info`] on success.
#[derive(Debug, Serialize)]
pub struct ShardInfoResponse {
/// The execution id from the query, after being parsed as an `i64`.
pub execution_id: i64,
/// The shard count from the query, echoed back unchanged.
pub shard_count: u32,
/// Result of `shard_for(execution_id, shard_count)`.
pub shard_index: u32,
/// Label identifying who produced the answer; the handler always sets
/// `"noetl-server"`.
pub source: &'static str,
/// Name of the hash function reported to the caller; the handler always
/// sets `"twox_hash::XxHash64"`.
// NOTE(review): this is a fixed literal, not derived from `shard_for` —
// keep it in sync with the sharding implementation.
pub hash_function: &'static str,
/// Hash seed reported to the caller; the handler always sets `0`.
pub seed: u64,
/// The answering server's own shard settings.
pub server_config: ServerShardConfig,
}
pub async fn get_shard_info(
State(state): State<AppState>,
Query(params): Query<ShardInfoQuery>,
) -> Result<Json<ShardInfoResponse>, (StatusCode, Json<serde_json::Value>)> {
let execution_id: i64 = params.execution_id.parse().map_err(|_| {
(
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": format!("execution_id {:?} is not a valid i64", params.execution_id),
})),
)
})?;
if params.shard_count == 0 {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": "shard_count must be >= 1",
})),
));
}
if params.shard_count > 1024 {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": format!(
"shard_count {} exceeds practical maximum 1024",
params.shard_count
),
})),
));
}
let shard_index = shard_for(execution_id, params.shard_count);
Ok(Json(ShardInfoResponse {
execution_id,
shard_count: params.shard_count,
shard_index,
source: "noetl-server",
hash_function: "twox_hash::XxHash64",
seed: 0,
server_config: ServerShardConfig {
shard_index: state.shard.shard_index,
shard_count: state.shard.shard_count,
},
}))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// For 16 shards, every sample execution id (small, large, `i64::MAX`,
    /// negative) must map to an index below the shard count.
    ///
    // NOTE(review): despite the name, no specific shard values are pinned
    // here — only the range is checked.
    #[test]
    fn shard_index_matches_pinned_values_n_16() {
        let shard_total = 16;
        let sample_ids: [i64; 5] = [1, 42, 320_816_801_799_737_344, i64::MAX, -1];
        for id in sample_ids.iter() {
            let assigned = shard_for(*id, shard_total);
            assert!(
                assigned < shard_total,
                "shard {s} out of range for eid={eid}, n={n}",
                s = assigned,
                eid = id,
                n = shard_total,
            );
        }
    }

    /// Decimal strings, including a negative and a large id, parse as `i64`
    /// and print back to exactly the same text.
    #[test]
    fn execution_id_parses_as_i64() {
        let accepted = ["1", "42", "-1", "9999999999", "320816801799737344"];
        for text in accepted.iter() {
            let value = text.parse::<i64>().expect("valid i64");
            assert_eq!(value.to_string(), *text);
        }
    }

    /// Words, decimals, hex notation, the empty string and digit separators
    /// are all refused by the `i64` parser.
    #[test]
    fn execution_id_rejects_non_numeric() {
        let refused = ["abc", "12.5", "0x10", "", "1_000"];
        for text in refused.iter() {
            let outcome = text.parse::<i64>();
            assert!(
                outcome.is_err(),
                "expected {s:?} to fail i64 parse",
                s = text,
            );
        }
    }
}