use std::{error, fmt, io};
use base64::{engine::general_purpose, Engine as _};
use crate::shared::EncryptionKey;
use super::{
CacheWrapper, RequestPublicKey, RequestSigningPublicKey
};
use crate::shared::NCRYPTF_CONTENT_TYPE;
use anyhow::anyhow;
use rocket::{
data::{FromData, Limits, Outcome},
http::{ContentType, Header, Status},
response::{self, Responder, Response},
Data, Request, State,
};
use serde::{Deserialize, Serialize};
/// Errors produced while reading, decrypting, or parsing a JSON request body.
#[derive(Debug)]
pub enum Error<'a> {
    /// Reading the body failed, the body exceeded the configured size limit
    /// (`ErrorKind::UnexpectedEof`), or key lookup / decryption failed (`ErrorKind::Other`).
    Io(io::Error),
    /// `serde_json` could not deserialize the body; carries the input string that failed to
    /// parse alongside the parser error.
    Parse(&'a str, serde_json::error::Error),
}
impl<'a> fmt::Display for Error<'a> {
    /// Formats the error as `"<category>: <underlying error>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (category, cause): (&str, &dyn fmt::Display) = match self {
            Self::Io(io_err) => ("i/o error", io_err),
            Self::Parse(_, json_err) => ("parse error", json_err),
        };
        write!(f, "{}: {}", category, cause)
    }
}
impl<'a> error::Error for Error<'a> {
    /// Both variants wrap an underlying error, which is always exposed as the cause.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        let cause: &(dyn error::Error + 'static) = match self {
            Self::Io(io_err) => io_err,
            Self::Parse(_, json_err) => json_err,
        };
        Some(cause)
    }
}
/// Wrapper that acts as a Rocket JSON request guard and responder with ncryptf support.
///
/// As a [`FromData`] guard, bodies whose `Content-Type` equals `NCRYPTF_CONTENT_TYPE` are
/// decrypted before deserialization; all other bodies are parsed as plain JSON. As a
/// [`Responder`], the value is encrypted when the request's `Accept` header equals
/// `NCRYPTF_CONTENT_TYPE`, and sent as plain JSON otherwise.
#[derive(Debug, Clone)]
pub struct Json<T>(pub T);
impl<T> Json<T> {
pub fn into_inner(self) -> T {
return self.0;
}
pub fn from_value(value: T) -> Self {
return Self(value);
}
pub fn deserialize_req_from_string<'r>(
req: &'r Request<'_>,
string: String,
cache: &CacheWrapper,
) -> Result<String, Error<'r>> {
match req.headers().get_one("Content-Type") {
Some(h) => {
match h {
NCRYPTF_CONTENT_TYPE => {
let data = general_purpose::STANDARD.decode(string).unwrap();
let hash_id = match req.headers().get_one("X-HashId") {
Some(h) => h,
None => {
return Err(Error::Io(io::Error::new(
io::ErrorKind::Other,
"Missing client provided hash identifier.",
)));
}
};
let ek = match cache.get(hash_id) {
Some(ek) => ek,
None => {
return Err(Error::Io(io::Error::new(
io::ErrorKind::Other,
"Encryption key is either invalid, or may have expired.",
)));
}
};
let sk = ek.get_box_kp().get_secret_key();
if ek.is_ephemeral() {
cache.remove(hash_id);
}
match crate::Response::from(sk) {
Ok(response) => {
match crate::Response::get_public_key_from_response(data.clone()) {
Ok(cpk) => {
req.local_cache(|| {
return RequestPublicKey(cpk.clone());
});
}
Err(error) => {
return Err(Error::Io(io::Error::new(
io::ErrorKind::Other,
error.to_string(),
)));
}
};
match crate::Response::get_signing_public_key_from_response(
data.clone(),
) {
Ok(cpk) => {
req.local_cache(|| {
return RequestSigningPublicKey(cpk.clone());
});
}
Err(error) => {
return Err(Error::Io(io::Error::new(
io::ErrorKind::Other,
error.to_string(),
)));
}
};
let public_key = match req.headers().get_one("X-PubKey") {
Some(h) => Some(h.as_bytes().to_vec()),
None => None,
};
let nonce = match req.headers().get_one("X-Nonce") {
Some(h) => Some(h.as_bytes().to_vec()),
None => None,
};
match response.decrypt(data.clone(), public_key, nonce) {
Ok(msg) => {
return Ok(req.local_cache(|| return msg).to_owned());
}
Err(error) => {
return Err(Error::Io(io::Error::new(
io::ErrorKind::Other,
error.to_string(),
)));
}
};
}
Err(error) => {
return Err(Error::Io(io::Error::new(
io::ErrorKind::Other,
error.to_string(),
)));
}
};
}
"json" => {
return Ok(req.local_cache(|| return string).to_owned());
}
_ => {
return Ok(req.local_cache(|| return string).to_owned());
}
}
}
None => {
return Ok(req.local_cache(|| return string).to_owned());
}
};
}
}
impl<'r, T: Deserialize<'r>> Json<T> {
pub fn from_str(s: &'r str) -> Result<Self, Error<'r>> {
return serde_json::from_str(s)
.map(Json)
.map_err(|e| Error::Parse(s, e));
}
pub async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Result<Self, Error<'r>> {
parse_body(req, data).await
}
}
#[rocket::async_trait]
impl<'r, T: Deserialize<'r>> FromData<'r> for Json<T> {
type Error = Error<'r>;
async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> {
match Self::from_data(req, data).await {
Ok(value) => Outcome::Success(value),
Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => {
Outcome::Error((Status::PayloadTooLarge, Error::Io(e)))
}
Err(Error::Parse(s, e)) if e.classify() == serde_json::error::Category::Data => {
req.local_cache(|| return "".to_string());
Outcome::Error((Status::UnprocessableEntity, Error::Parse(s, e)))
}
Err(e) => Outcome::Error((Status::BadRequest, e)),
}
}
}
impl<'r, T: Serialize> Responder<'r, 'static> for Json<T> {
    /// Sends the value as plain or ncryptf-encrypted JSON with status 200.
    /// Any failure while building the response becomes a 500.
    fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> {
        respond_to_with_ncryptf(self, Status::Ok, req)
            .unwrap_or_else(|_| Err(Status::InternalServerError))
    }
}
/// A [`Json`] payload paired with an explicit HTTP status for the response.
#[derive(Debug, Clone)]
pub struct JsonResponse<T> {
    /// Status code set on the outgoing response.
    pub status: Status,
    /// Response body, serialized (and possibly encrypted) by `respond_to_with_ncryptf`.
    pub json: Json<T>,
}
impl<'r, T: Serialize> Responder<'r, 'static> for JsonResponse<T> {
fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> {
return match respond_to_with_ncryptf(self.json, self.status, req) {
Ok(response) => response,
Err(_) => return Err(Status::InternalServerError),
};
}
}
pub async fn parse_body<'r, T: Deserialize<'r>>(req: &'r Request<'_>, data: Data<'r>) -> Result<Json<T>, Error<'r>> {
let limit = req.limits().get("json").unwrap_or(Limits::JSON);
let string = match data.open(limit).into_string().await {
Ok(s) if s.is_complete() => s.into_inner(),
Ok(_) => {
let eof = io::ErrorKind::UnexpectedEof;
return Err(Error::Io(io::Error::new(eof, "data limit exceeded")));
}
Err(error) => return Err(Error::Io(error)),
};
let cache = match req.rocket().state::<CacheWrapper>() {
Some(cache) => cache,
None => {
return Err(Error::Io(io::Error::new(
io::ErrorKind::Other,
"Cache not found in managed state",
)));
}
};
match Json::<T>::deserialize_req_from_string(req, string, cache) {
Ok(s) => {
return Json::<T>::from_str(req.local_cache(|| return s));
}
Err(error) => return Err(error),
};
}
pub fn respond_to_with_ncryptf<'r, 'a, T: serde::Serialize>(
m: Json<T>,
status: Status,
req: &'r Request<'_>,
) -> Result<response::Result<'static>, anyhow::Error> {
let message = match serde_json::to_string(&m.0) {
Ok(json) => json,
Err(_error) => return Err(anyhow!("Could not deserialize message")),
};
match req.headers().get_one("Accept") {
Some(accept) => {
match accept {
NCRYPTF_CONTENT_TYPE => {
let cpk = req.local_cache(|| {
return RequestPublicKey(Vec::<u8>::new());
});
let pk: Vec<u8>;
if cpk.0.is_empty() {
pk = match req.headers().get_one("X-PubKey") {
Some(h) => general_purpose::STANDARD.decode(h).unwrap(),
None =>return Err(anyhow!("Public key is not available on request. Unable to re-encrypt message to client."))
};
} else {
pk = cpk.0.clone();
}
let ek = EncryptionKey::new(false);
let cache = match req.rocket().state::<CacheWrapper>() {
Some(cache) => cache,
None => return Err(anyhow!("Cache not found in managed state")),
};
cache.set(ek.get_hash_id(), ek.clone());
let mut request = match crate::Request::from(
ek.get_box_kp().get_secret_key(),
ek.get_sign_kp().get_secret_key(),
) {
Ok(request) => request,
Err(_error) => return Err(anyhow!("Unable to encrypt message")),
};
let content = match request.encrypt(message, pk) {
Ok(content) => content,
Err(_error) => return Err(anyhow!("Unable to encrypt message")),
};
let d = general_purpose::STANDARD.encode(content);
let respond_to = match d.respond_to(req) {
Ok(s) => s,
Err(_) => return Err(anyhow!("Could not send response")),
};
return Ok(Response::build_from(respond_to)
.header(ContentType::new("application", "vnd.ncryptf+json"))
.header(Header::new(
"x-public-key",
general_purpose::STANDARD.encode(ek.get_box_kp().get_public_key()),
))
.header(Header::new(
"x-signature-public-key",
general_purpose::STANDARD.encode(ek.get_sign_kp().get_public_key()),
))
.header(Header::new(
"x-public-key-expiration",
ek.expires_at.to_string(),
))
.header(Header::new("x-hashid", ek.get_hash_id()))
.status(status)
.ok());
}
_ => {
let respond_to = match message.respond_to(req) {
Ok(s) => s,
Err(_) => return Err(anyhow!("Could not send response")),
};
return Ok(Response::build_from(respond_to)
.header(ContentType::new("application", "json"))
.status(status)
.ok());
}
}
}
None => {
let respond_to = match message.respond_to(req) {
Ok(s) => s,
Err(_) => return Err(anyhow!("Could not send response")),
};
return Ok(Response::build_from(respond_to)
.header(ContentType::new("application", "json"))
.status(status)
.ok());
}
}
}