use axum::http::{HeaderMap, header};
use bytes::Bytes;
use prost::Message;
use serde::Deserialize;
use super::proto;
use super::response::is_binary_protobuf;
use crate::Error;
fn is_protobuf_content(headers: &HeaderMap) -> bool {
headers
.get(header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(is_binary_protobuf)
.unwrap_or(false)
}
/// Parameters for a scan over the entries stored under a single key.
///
/// Deserialized with serde; the field names below are the expected wire names
/// (e.g. `timeout_ms`). NOTE(review): presumably bound from the HTTP query string —
/// confirm against the handler.
#[derive(Debug, Deserialize)]
pub struct ScanParams {
    /// Key to scan; exposed as [`Bytes`] through [`ScanParams::key`].
    pub key: String,
    /// Inclusive start of the sequence range; treated as `0` when absent
    /// (see [`ScanParams::seq_range`]).
    pub start_seq: Option<u64>,
    /// Exclusive end of the sequence range; treated as `u64::MAX` when absent
    /// (see [`ScanParams::seq_range`]).
    pub end_seq: Option<u64>,
    /// Optional limit. NOTE(review): presumably caps the number of returned entries; it is
    /// not applied in this file — confirm where it is enforced.
    pub limit: Option<usize>,
    /// Optional follow flag, left as `None` when absent. NOTE(review): presumably requests
    /// tailing new entries and is treated as `false` when unset — confirm in the handler.
    pub follow: Option<bool>,
    /// Optional timeout, left as `None` when absent. NOTE(review): the name suggests
    /// milliseconds and presumably bounds how long a follow waits — confirm.
    pub timeout_ms: Option<u64>,
}
impl ScanParams {
    /// Returns the requested key as an owned [`Bytes`] buffer.
    ///
    /// The UTF-8 bytes of `key` are copied, so `self` is left untouched.
    pub fn key(&self) -> Bytes {
        Bytes::copy_from_slice(self.key.as_bytes())
    }

    /// Returns the half-open sequence range `start_seq..end_seq`.
    ///
    /// A missing `start_seq` becomes `0` and a missing `end_seq` becomes `u64::MAX`.
    /// Omitting both therefore covers every sequence number below `u64::MAX`.
    pub fn seq_range(&self) -> std::ops::Range<u64> {
        let lower = self.start_seq.unwrap_or_default();
        let upper = self.end_seq.unwrap_or(u64::MAX);
        lower..upper
    }
}
/// Parameters for a list-keys request, bounded by a range of segment numbers.
///
/// NOTE(review): presumably bound from the HTTP query string — confirm against the handler.
#[derive(Debug, Deserialize)]
pub struct ListKeysParams {
    /// Inclusive start of the segment range; treated as `0` when absent
    /// (see [`ListKeysParams::segment_range`]).
    pub start_segment: Option<u32>,
    /// Exclusive end of the segment range; treated as `u32::MAX` when absent
    /// (see [`ListKeysParams::segment_range`]).
    pub end_segment: Option<u32>,
    /// Optional limit. NOTE(review): not applied in this file — confirm where it is enforced.
    pub limit: Option<usize>,
}
impl ListKeysParams {
    /// Returns the half-open segment range `start_segment..end_segment`.
    ///
    /// An unset start defaults to `0` and an unset end defaults to `u32::MAX`.
    pub fn segment_range(&self) -> std::ops::Range<u32> {
        let Self {
            start_segment,
            end_segment,
            ..
        } = self;
        start_segment.unwrap_or(0)..end_segment.unwrap_or(u32::MAX)
    }
}
/// Parameters for a list-segments request, bounded by a range of sequence numbers.
///
/// NOTE(review): presumably bound from the HTTP query string — confirm against the handler.
#[derive(Debug, Deserialize)]
pub struct ListSegmentsParams {
    /// Inclusive start of the sequence range; treated as `0` when absent
    /// (see [`ListSegmentsParams::seq_range`]).
    pub start_seq: Option<u64>,
    /// Exclusive end of the sequence range; treated as `u64::MAX` when absent
    /// (see [`ListSegmentsParams::seq_range`]).
    pub end_seq: Option<u64>,
}
impl ListSegmentsParams {
    /// Returns the half-open sequence range `start_seq..end_seq`.
    ///
    /// Unset bounds fall back to `0` for the start and `u64::MAX` for the end.
    pub fn seq_range(&self) -> std::ops::Range<u64> {
        std::ops::Range {
            start: self.start_seq.unwrap_or(0),
            end: self.end_seq.unwrap_or(u64::MAX),
        }
    }
}
/// Parameters for a count request over the entries stored under a single key.
///
/// NOTE(review): presumably bound from the HTTP query string — confirm against the handler.
#[derive(Debug, Deserialize)]
pub struct CountParams {
    /// Key whose entries are counted; exposed as [`Bytes`] through [`CountParams::key`].
    pub key: String,
    /// Inclusive start of the sequence range; treated as `0` when absent
    /// (see [`CountParams::seq_range`]).
    pub start_seq: Option<u64>,
    /// Exclusive end of the sequence range; treated as `u64::MAX` when absent
    /// (see [`CountParams::seq_range`]).
    pub end_seq: Option<u64>,
}
impl CountParams {
pub fn key(&self) -> Bytes {
Bytes::from(self.key.clone())
}
pub fn seq_range(&self) -> std::ops::Range<u64> {
let start = self.start_seq.unwrap_or(0);
let end = self.end_seq.unwrap_or(u64::MAX);
start..end
}
}
/// A validated append request in which every record has both a key and a value.
///
/// Built from an HTTP body by [`AppendRequest::from_body`], which accepts either binary
/// protobuf or JSON.
#[derive(Debug)]
pub struct AppendRequest {
    /// Records to append, in the order they appeared in the request body.
    pub records: Vec<crate::Record>,
    /// Copied from the request's `await_durable` field (`awaitDurable` in JSON); `false`
    /// when omitted. NOTE(review): presumably asks the server to wait for durability before
    /// acknowledging — confirm in the handler.
    pub await_durable: bool,
}
impl AppendRequest {
pub fn from_body(headers: &HeaderMap, body: &[u8]) -> Result<Self, Error> {
if is_protobuf_content(headers) {
Self::from_protobuf(body)
} else {
Self::from_json(body)
}
}
fn from_protobuf(body: &[u8]) -> Result<Self, Error> {
let proto_request = proto::AppendRequest::decode(body)
.map_err(|e| Error::InvalidInput(format!("Invalid protobuf: {}", e)))?;
Self::from_proto_request(proto_request)
}
fn from_json(body: &[u8]) -> Result<Self, Error> {
let proto_request: proto::AppendRequest = serde_json::from_slice(body)
.map_err(|e| Error::InvalidInput(format!("Invalid JSON: {}", e)))?;
Self::from_proto_request(proto_request)
}
fn from_proto_request(proto_request: proto::AppendRequest) -> Result<Self, Error> {
let mut records = Vec::with_capacity(proto_request.records.len());
for (i, r) in proto_request.records.into_iter().enumerate() {
let key = r
.key
.ok_or_else(|| Error::InvalidInput(format!("record[{}]: key is required", i)))?;
let value = r
.value
.ok_or_else(|| Error::InvalidInput(format!("record[{}]: value is required", i)))?;
records.push(crate::Record { key, value });
}
Ok(Self {
records,
await_durable: proto_request.await_durable,
})
}
}
#[cfg(test)]
mod tests {
    use axum::http::HeaderValue;

    use super::*;

    /// Headers declaring a binary protobuf body.
    fn protobuf_headers() -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("application/protobuf"),
        );
        headers
    }

    /// Headers declaring a protobuf-JSON body, which `from_body` parses as JSON.
    fn protojson_headers() -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("application/protobuf+json"),
        );
        headers
    }

    #[test]
    fn should_get_key_as_bytes() {
        let params = ScanParams {
            key: "my-key".to_string(),
            start_seq: None,
            end_seq: None,
            limit: None,
            follow: None,
            timeout_ms: None,
        };
        let key = params.key();
        assert_eq!(key.as_ref(), b"my-key");
    }

    #[test]
    fn should_parse_append_request_from_json() {
        let json = br#"{
            "records": [{"key": "dGVzdC1rZXk=", "value": "dGVzdC12YWx1ZQ=="}],
            "awaitDurable": true
        }"#;
        let request = AppendRequest::from_body(&protojson_headers(), json).unwrap();
        assert_eq!(request.records.len(), 1);
        assert_eq!(request.records[0].key, Bytes::from("test-key"));
        assert_eq!(request.records[0].value, Bytes::from("test-value"));
        assert!(request.await_durable);
    }

    #[test]
    fn should_parse_append_request_from_json_without_await_durable() {
        let json = br#"{
            "records": [{"key": "a2V5", "value": "dmFsdWU="}]
        }"#;
        let request = AppendRequest::from_body(&protojson_headers(), json).unwrap();
        assert!(!request.await_durable);
    }

    #[test]
    fn should_parse_append_request_as_json_without_content_type() {
        // With no Content-Type header at all, from_body falls back to JSON.
        let json = br#"{"records": [{"key": "a2V5", "value": "dmFsdWU="}]}"#;
        let request = AppendRequest::from_body(&HeaderMap::new(), json).unwrap();
        assert_eq!(request.records.len(), 1);
        assert_eq!(request.records[0].key, Bytes::from("key"));
        assert_eq!(request.records[0].value, Bytes::from("value"));
        assert!(!request.await_durable);
    }

    #[test]
    fn should_parse_append_request_from_protobuf() {
        let proto_request = proto::AppendRequest {
            records: vec![proto::Record {
                key: Some(Bytes::from("proto-key")),
                value: Some(Bytes::from("proto-value")),
            }],
            await_durable: true,
        };
        let body = proto_request.encode_to_vec();
        let request = AppendRequest::from_body(&protobuf_headers(), &body).unwrap();
        assert_eq!(request.records.len(), 1);
        assert_eq!(request.records[0].key, Bytes::from("proto-key"));
        assert_eq!(request.records[0].value, Bytes::from("proto-value"));
        assert!(request.await_durable);
    }

    #[test]
    fn should_return_error_for_missing_key() {
        let json = br#"{
            "records": [{"value": "dmFsdWU="}]
        }"#;
        let result = AppendRequest::from_body(&protojson_headers(), json);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("key is required"));
    }

    #[test]
    fn should_return_error_for_missing_value() {
        let json = br#"{
            "records": [{"key": "a2V5"}]
        }"#;
        let result = AppendRequest::from_body(&protojson_headers(), json);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("value is required")
        );
    }

    #[test]
    fn should_return_error_for_invalid_json() {
        let body = b"not valid json";
        let result = AppendRequest::from_body(&protojson_headers(), body);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid JSON"));
    }

    #[test]
    fn should_return_error_for_invalid_protobuf() {
        let body = &[0xFF, 0xFF, 0xFF];
        let result = AppendRequest::from_body(&protobuf_headers(), body);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid protobuf"));
    }

    #[test]
    fn should_return_default_seq_range() {
        let params = ScanParams {
            key: "test".to_string(),
            start_seq: None,
            end_seq: None,
            limit: None,
            follow: None,
            timeout_ms: None,
        };
        let range = params.seq_range();
        assert_eq!(range.start, 0);
        assert_eq!(range.end, u64::MAX);
    }

    #[test]
    fn should_use_provided_seq_range() {
        let params = ScanParams {
            key: "test".to_string(),
            start_seq: Some(10),
            end_seq: Some(100),
            limit: None,
            follow: None,
            timeout_ms: None,
        };
        let range = params.seq_range();
        assert_eq!(range.start, 10);
        assert_eq!(range.end, 100);
    }

    #[test]
    fn should_leave_follow_and_timeout_unset_when_absent() {
        // Deserialization applies no default: an omitted `follow` stays `None`.
        let json = r#"{"key": "test"}"#;
        let params: ScanParams = serde_json::from_str(json).unwrap();
        assert_eq!(params.follow, None);
        assert!(params.timeout_ms.is_none());
    }

    #[test]
    fn should_parse_follow_and_timeout_params() {
        let json = r#"{"key": "test", "follow": true, "timeout_ms": 5000}"#;
        let params: ScanParams = serde_json::from_str(json).unwrap();
        assert_eq!(params.follow, Some(true));
        assert_eq!(params.timeout_ms, Some(5000));
    }

    #[test]
    fn should_return_default_segment_range() {
        let params = ListKeysParams {
            start_segment: None,
            end_segment: None,
            limit: None,
        };
        assert_eq!(params.segment_range(), 0..u32::MAX);
    }

    #[test]
    fn should_use_provided_segment_range() {
        let params = ListKeysParams {
            start_segment: Some(3),
            end_segment: Some(9),
            limit: Some(10),
        };
        assert_eq!(params.segment_range(), 3..9);
    }

    #[test]
    fn should_return_list_segments_seq_range() {
        let unbounded = ListSegmentsParams {
            start_seq: None,
            end_seq: None,
        };
        assert_eq!(unbounded.seq_range(), 0..u64::MAX);

        let bounded = ListSegmentsParams {
            start_seq: Some(5),
            end_seq: Some(50),
        };
        assert_eq!(bounded.seq_range(), 5..50);
    }

    #[test]
    fn should_get_count_key_and_seq_range() {
        let params = CountParams {
            key: "count-key".to_string(),
            start_seq: None,
            end_seq: Some(42),
        };
        assert_eq!(params.key().as_ref(), b"count-key");
        assert_eq!(params.seq_range(), 0..42);
    }
}