use crate::core::error::Result;
use crate::core::protocol;
use crate::core::protocol::constants::{headers, media_types};
use crate::core::types::{Update, Version};
use axum::{
body::Body,
http::{header, HeaderValue, StatusCode},
response::{IntoResponse, Response},
};
use bytes::Bytes;
use std::collections::BTreeMap;
/// Extension trait for values that can send an [`Update`] or a raw byte body.
///
/// Both methods take `&mut self` and report failure through the crate-level
/// [`Result`] alias.
///
/// NOTE(review): no implementor is visible in this file, so the concrete
/// transport (and what "send" writes to) must be confirmed at the impl site.
pub trait SendUpdateExt {
/// Sends `update` through `self`.
fn send_update(&mut self, update: &Update) -> Result<()>;
/// Sends the bytes in `body` through `self`.
fn send_body(&mut self, body: &[u8]) -> Result<()>;
}
/// Builder that collects a status code, string headers and an optional body,
/// and is turned into an axum [`Response`] by its `build` method.
pub struct UpdateResponse {
/// Numeric HTTP status code, as passed to `UpdateResponse::new`.
status: u16,
/// Header name/value pairs kept as plain strings. Because this is a map,
/// inserting an existing key replaces the previous value.
headers: BTreeMap<String, String>,
/// Response payload; `None` means the response is built with an empty body.
body: Option<Bytes>,
}
impl UpdateResponse {
pub fn new(status: u16) -> Self {
UpdateResponse {
status,
headers: BTreeMap::new(),
body: None,
}
}
pub fn with_version(mut self, versions: Vec<Version>) -> Self {
let version_str = protocol::format_version_header(&versions);
self.headers
.insert(headers::VERSION.as_str().to_string(), version_str);
self
}
pub fn with_parents(mut self, parents: Vec<Version>) -> Self {
let parents_str = protocol::format_version_header(&parents);
self.headers
.insert(headers::PARENTS.as_str().to_string(), parents_str);
self
}
pub fn with_current_version(mut self, versions: Vec<Version>) -> Self {
let current_version_str = protocol::format_version_header(&versions);
self.headers.insert(
headers::CURRENT_VERSION.as_str().to_string(),
current_version_str,
);
self
}
pub fn with_body(mut self, body: impl Into<Bytes>) -> Self {
self.body = Some(body.into());
self
}
pub fn with_header(mut self, key: String, value: String) -> Self {
self.headers.insert(key, value);
self
}
pub fn build(self) -> Response {
let mut response = match self.status {
200 => Response::builder().status(StatusCode::OK),
209 => Response::builder().status(StatusCode::from_u16(209).unwrap()),
404 => Response::builder().status(StatusCode::NOT_FOUND),
500 => Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR),
_ => Response::builder().status(StatusCode::from_u16(self.status).unwrap()),
};
for (key, value) in &self.headers {
if let Ok(header_value) = value.parse::<HeaderValue>() {
response = response.header(key, header_value);
}
}
if let Some(body) = self.body {
response
.header(header::CONTENT_LENGTH, body.len())
.body(Body::from(body))
.unwrap_or_else(|_| Response::default())
} else {
response
.body(Body::empty())
.unwrap_or_else(|_| Response::default())
}
}
}
/// Converts an [`Update`] into an axum [`Response`] by way of
/// [`UpdateResponse`].
///
/// Headers are written in this order, and a later write to the same key
/// replaces an earlier one:
///
/// 1. version, parents and current-version (only when present / non-empty),
/// 2. `Content-Type`: the update's own value, or `media_types::BRAID_PATCH`
///    when there is none but patches exist,
/// 3. every entry of `extra_headers`,
/// 4. when there is no explicit body but patches exist: the patch-count
///    header, plus `Content-Range` for the single-patch case.
///
/// An explicit body takes priority over patches. A single patch is sent as
/// the body itself; several patches are concatenated, each preceded by its
/// own `Content-Length` / `Content-Range` header block and followed by CRLF.
impl IntoResponse for Update {
    fn into_response(self) -> Response {
        let mut builder = UpdateResponse::new(self.status);

        // `self` is consumed here, so owned fields are moved into the builder
        // instead of being cloned.
        if !self.version.is_empty() {
            builder = builder.with_version(self.version);
        }
        if !self.parents.is_empty() {
            builder = builder.with_parents(self.parents);
        }
        if let Some(current_version) = self.current_version {
            builder = builder.with_current_version(current_version);
        }

        let content_type = match self.content_type {
            Some(explicit) => Some(explicit),
            None if self.patches.is_some() => Some(media_types::BRAID_PATCH.to_string()),
            None => None,
        };
        if let Some(content_type) = content_type {
            builder = builder.with_header(header::CONTENT_TYPE.as_str().to_string(), content_type);
        }

        // Extra headers come after the defaults so they can override them.
        for (key, value) in self.extra_headers {
            builder = builder.with_header(key, value);
        }

        if let Some(body) = self.body {
            builder = builder.with_body(body);
        } else if let Some(patches) = &self.patches {
            builder = builder.with_header(
                headers::PATCHES.as_str().to_string(),
                patches.len().to_string(),
            );

            match patches.len() {
                // No patches: only the count header is sent, with no body.
                0 => {}
                1 => {
                    let patch = &patches[0];
                    let content_range = format!("{} {}", patch.unit, patch.range);
                    builder = builder
                        .with_header(headers::CONTENT_RANGE.as_str().to_string(), content_range)
                        .with_body(patch.content.clone());
                }
                _ => {
                    use bytes::BufMut;

                    let mut multi_body = bytes::BytesMut::new();
                    for patch in patches {
                        let patch_headers = format!(
                            "Content-Length: {}\r\nContent-Range: {} {}\r\n\r\n",
                            patch.len(),
                            patch.unit,
                            patch.range
                        );
                        multi_body.put_slice(patch_headers.as_bytes());
                        multi_body.put_slice(&patch.content);
                        multi_body.put_slice(b"\r\n");
                    }
                    builder = builder.with_body(multi_body.freeze());
                }
            }
        }

        builder.build()
    }
}
/// Numeric status codes used by this module that have no named constant on
/// `StatusCode`, together with helpers that return them as `StatusCode`.
pub mod status {
    use axum::http::StatusCode;

    /// Status code returned by [`subscription_response`].
    // Not referenced elsewhere in this file; kept for users of the module.
    #[allow(dead_code)]
    pub const SUBSCRIPTION: u16 = 209;

    /// Status code returned by [`multiplex_response`].
    // Not referenced elsewhere in this file; kept for users of the module.
    #[allow(dead_code)]
    pub const RESPONDED_VIA_MULTIPLEX: u16 = 293;

    /// Converts one of this module's constants into a `StatusCode`.
    ///
    /// # Panics
    ///
    /// Panics if `StatusCode::from_u16` rejects `raw`; it is only called with
    /// the two constants above.
    #[allow(dead_code)]
    fn code_for(raw: u16) -> StatusCode {
        StatusCode::from_u16(raw).unwrap()
    }

    /// Returns [`SUBSCRIPTION`] as a `StatusCode`.
    #[allow(dead_code)]
    pub fn subscription_response() -> StatusCode {
        code_for(SUBSCRIPTION)
    }

    /// Returns [`RESPONDED_VIA_MULTIPLEX`] as a `StatusCode`.
    #[allow(dead_code)]
    pub fn multiplex_response() -> StatusCode {
        code_for(RESPONDED_VIA_MULTIPLEX)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A builder with a version and a custom header yields a 200 response.
    #[test]
    fn test_update_response_builder() {
        let builder = UpdateResponse::new(200)
            .with_version(vec![Version::from("v1")])
            .with_header(String::from("Custom"), String::from("value"));

        let built = builder.build();

        assert_eq!(built.status(), StatusCode::OK);
    }

    /// The version passed to the builder shows up in the `version` header.
    #[test]
    fn test_version_in_response() {
        let built = UpdateResponse::new(200)
            .with_version(vec![Version::from("v42")])
            .build();

        let version_text = built
            .headers()
            .get("version")
            .and_then(|raw| raw.to_str().ok());

        // Fails both when the header is absent and when it lacks the version id.
        assert!(version_text.map_or(false, |text| text.contains("v42")));
    }

    /// Setting both version and parents emits both headers.
    #[test]
    fn test_version_with_parents_in_response() {
        let built = UpdateResponse::new(200)
            .with_version(vec![Version::from("v2")])
            .with_parents(vec![Version::from("v1")])
            .build();

        let header_map = built.headers();
        for expected in ["version", "parents"] {
            assert!(header_map.contains_key(expected));
        }
    }

    /// A patched update without an explicit content type gets the patch media type.
    #[test]
    fn test_patch_content_type() {
        use crate::core::types::Patch;

        let patched = Update::patched(Version::from("v1"), vec![Patch::json(".a", "1")]);
        let built: Response = patched.into_response();

        let content_type = built
            .headers()
            .get("content-type")
            .and_then(|raw| raw.to_str().ok());

        assert_eq!(content_type, Some("application/braid-patch"));
    }

    /// A subscription snapshot is answered with status 209.
    #[test]
    fn test_subscription_status() {
        let snapshot = Update::subscription_snapshot(Version::from("v1"), "data");
        let built: Response = snapshot.into_response();

        assert_eq!(u16::from(built.status()), 209);
    }
}