use std::any::Any;
use std::error::Error as StdError;
use std::io;
use std::pin::pin;
use bytes::{Buf, Bytes};
use headers::HeaderMapExt;
use http::StatusCode as SC;
use http::{self, Request, Response};
use http_body::Body as HttpBody;
use http_body_util::BodyExt;
use crate::body::Body;
use crate::conditional::if_match_get_tokens;
use crate::davheaders;
use crate::fs::*;
use crate::{DavError, DavInner, DavResult};
/// `Content-Type` that a `PATCH` request must carry for `handle_put` to accept
/// it as a SabreDAV-style partial update; any other (or missing) content type
/// is rejected with `415 Unsupported Media Type`.
const SABRE: &str = "application/x-sabredav-partialupdate";
/// Converts a request-body error into an `io::Error`.
///
/// Errors that are really an `io::Error` or a `DavError` (optionally wrapped in
/// a `Box`) are unwrapped and returned / converted directly, so the original
/// error is not buried under another layer. Any other error type is wrapped
/// with `io::Error::other`.
fn to_ioerror<E>(err: E) -> io::Error
where
    E: StdError + Sync + Send + 'static,
{
    // Decide up front whether `err` is one of the four types we know how to
    // unwrap; otherwise keep the error intact as the source of a new io::Error.
    let probe = &err as &dyn Any;
    let wraps_io = probe.is::<io::Error>() || probe.is::<Box<io::Error>>();
    let wraps_dav = probe.is::<DavError>() || probe.is::<Box<DavError>>();
    if !(wraps_io || wraps_dav) {
        return io::Error::other(err);
    }

    // Try each known concrete type in turn; `downcast` hands the box back
    // unchanged on a mismatch so the next attempt can reuse it.
    let boxed: Box<dyn Any> = Box::new(err);
    let boxed = match boxed.downcast::<io::Error>() {
        Ok(io_err) => return *io_err,
        Err(rest) => rest,
    };
    let boxed = match boxed.downcast::<Box<io::Error>>() {
        Ok(io_err) => return *(*io_err),
        Err(rest) => rest,
    };
    let boxed = match boxed.downcast::<DavError>() {
        Ok(dav_err) => return (*dav_err).into(),
        Err(rest) => rest,
    };
    match boxed.downcast::<Box<DavError>>() {
        Ok(dav_err) => (*(*dav_err)).into(),
        // Not expected: the type was already confirmed by the probe above.
        Err(_) => io::ErrorKind::Other.into(),
    }
}
impl<C: Clone + Send + Sync + 'static> DavInner<C> {
pub(crate) async fn handle_put<ReqBody, ReqData, ReqError>(
self,
req: &Request<()>,
body: ReqBody,
) -> DavResult<Response<Body>>
where
ReqBody: HttpBody<Data = ReqData, Error = ReqError>,
ReqData: Buf + Send + 'static,
ReqError: StdError + Send + Sync + 'static,
{
let mut start = 0;
let mut count = 0;
let mut have_count = false;
let mut do_range = false;
let mut oo = OpenOptions::write();
oo.create = true;
oo.truncate = true;
if let Some(n) = req.headers().typed_get::<headers::ContentLength>() {
count = n.0;
have_count = true;
oo.size = Some(count);
} else if let Some(n) = req
.headers()
.get("X-Expected-Entity-Length")
.and_then(|v| v.to_str().ok())
{
if let Ok(len) = n.parse() {
count = len;
have_count = true;
oo.size = Some(count);
}
}
let checksum = req
.headers()
.get("OC-Checksum")
.and_then(|v| v.to_str().ok().map(|s| s.to_string()));
oo.checksum = checksum;
let path = self.path(req);
let meta = self.fs.metadata(&path, &self.credentials).await;
let mut res = Response::new(Body::empty());
res.headers_mut().typed_insert(headers::Connection::close());
if req.method() == http::Method::PATCH {
if req
.headers()
.typed_get::<davheaders::ContentType>()
.is_none_or(|ct| ct.0 != SABRE)
{
return Err(DavError::StatusClose(SC::UNSUPPORTED_MEDIA_TYPE));
}
if !have_count {
return Err(DavError::StatusClose(SC::LENGTH_REQUIRED));
};
let r = req
.headers()
.typed_get::<davheaders::XUpdateRange>()
.ok_or(DavError::StatusClose(SC::BAD_REQUEST))?;
match r {
davheaders::XUpdateRange::FromTo(b, e) => {
if b > e || e - b + 1 != count {
return Err(DavError::StatusClose(SC::RANGE_NOT_SATISFIABLE));
}
start = b;
}
davheaders::XUpdateRange::AllFrom(b) => {
start = b;
}
davheaders::XUpdateRange::Last(n) => {
if let Ok(ref m) = meta {
if n > m.len() {
return Err(DavError::StatusClose(SC::RANGE_NOT_SATISFIABLE));
}
start = m.len() - n;
}
}
davheaders::XUpdateRange::Append => {
oo.append = true;
}
}
do_range = true;
oo.truncate = false;
}
match req.headers().typed_try_get::<headers::ContentRange>() {
Ok(Some(range)) => {
if let Some((b, e)) = range.bytes_range() {
if b > e {
return Err(DavError::StatusClose(SC::RANGE_NOT_SATISFIABLE));
}
if have_count {
if e - b + 1 != count {
return Err(DavError::StatusClose(SC::RANGE_NOT_SATISFIABLE));
}
} else {
count = e - b + 1;
have_count = true;
}
start = b;
do_range = true;
oo.truncate = false;
}
}
Ok(None) => {}
Err(_) => return Err(DavError::StatusClose(SC::BAD_REQUEST)),
}
let tokens = if_match_get_tokens(
req,
meta.as_ref().map(|v| v.as_ref()).ok(),
self.fs.as_ref(),
&self.ls,
&path,
&self.credentials,
);
let tokens = match tokens.await {
Ok(t) => t,
Err(s) => return Err(DavError::StatusClose(s)),
};
if let Some(ref locksystem) = self.ls {
let principal = self.principal.as_deref();
if let Err(_l) = locksystem
.check(&path, principal, false, false, &tokens)
.await
{
return Err(DavError::StatusClose(SC::LOCKED));
}
}
if req
.headers()
.typed_get::<davheaders::IfMatch>()
.is_some_and(|h| h.0 == davheaders::ETagList::Star)
{
oo.create = false;
}
if req
.headers()
.typed_get::<davheaders::IfNoneMatch>()
.is_some_and(|h| h.0 == davheaders::ETagList::Star)
{
oo.create_new = true;
}
let create = oo.create;
let create_new = oo.create_new;
let mut file = match self.fs.open(&path, oo, &self.credentials).await {
Ok(f) => f,
Err(FsError::NotFound) | Err(FsError::Exists) => {
let s = if !create || create_new {
SC::PRECONDITION_FAILED
} else {
SC::CONFLICT
};
return Err(DavError::StatusClose(s));
}
Err(e) => return Err(DavError::FsError(e)),
};
if do_range {
if file.seek(std::io::SeekFrom::Start(start)).await.is_err() {
return Err(DavError::StatusClose(SC::RANGE_NOT_SATISFIABLE));
}
}
res.headers_mut()
.typed_insert(headers::AcceptRanges::bytes());
let mut body = pin!(body);
let mut total = 0u64;
while let Some(data) = body.frame().await {
let data_frame = data.map_err(|e| to_ioerror(e))?;
let Ok(mut buf) = data_frame.into_data() else {
continue;
};
total += buf.remaining() as u64;
if have_count && total > count {
break;
}
let b = {
let b: &mut dyn std::any::Any = &mut buf;
b.downcast_mut::<Bytes>()
};
if let Some(bytes) = b {
let bytes = std::mem::replace(bytes, Bytes::new());
file.write_bytes(bytes).await?;
} else {
file.write_buf(Box::new(buf)).await?;
}
}
file.flush().await?;
if have_count && total > count {
error!("PUT file: sender is sending more bytes than expected");
return Err(DavError::StatusClose(SC::BAD_REQUEST));
}
if have_count && total < count {
error!("PUT file: premature EOF on input");
return Err(DavError::StatusClose(SC::BAD_REQUEST));
}
*res.status_mut() = match meta {
Ok(_) => SC::NO_CONTENT,
Err(_) => {
res.headers_mut().typed_insert(headers::ContentLength(0));
SC::CREATED
}
};
res.headers_mut().remove(http::header::CONNECTION);
if let Ok(meta) = file.metadata().await {
if let Some(etag) = davheaders::ETag::from_meta(meta.as_ref()) {
res.headers_mut().typed_insert(etag);
}
if let Ok(modified) = meta.modified() {
res.headers_mut()
.typed_insert(headers::LastModified::from(modified));
}
}
Ok(res)
}
}