use std::sync::Arc;
use futures::StreamExt;
use crate::auth::{AuthProvider, DefaultTransport, TransportConfig};
use crate::error::ClientError;
use crate::request::Session;
use crate::sse::{parse_sse_block, SseFrame};
/// Carry-over state for the SSE frame splitter driven by `futures::stream::unfold`.
struct SseStreamState<S> {
    /// Offset into `buf` where the next search for an event terminator starts.
    scan_from: usize,
    /// Decoded text received so far that has not yet been emitted as a frame.
    buf: String,
    /// Raw bytes received from the network that have not been decoded into `buf` yet.
    raw_buf: Vec<u8>,
    /// Underlying byte stream of the HTTP response body.
    stream: S,
}
/// Tunable limits and timeouts for [`JmapClient`].
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientConfig {
    /// Timeout applied by `fetch_session` and `call`; the SSE subscription does not use it.
    pub request_timeout: std::time::Duration,
    /// Maximum accepted size, in bytes, of the session document.
    pub max_session_body: u64,
    /// Maximum accepted size, in bytes, of an API call response.
    pub max_call_body: u64,
    /// Download body limit in bytes (not used in this file; presumably enforced by the
    /// download helper — verify).
    pub max_download_body: u64,
    /// Upload response body limit in bytes (not used in this file; presumably enforced by
    /// the upload helper — verify).
    pub max_upload_response_body: u64,
    /// Maximum size, in bytes, of a single Server-Sent Events frame.
    pub max_sse_frame: usize,
    /// Message size limit handed to the WebSocket connector.
    pub max_ws_message: usize,
}

impl Default for ClientConfig {
    /// 30 s timeout, 1 MiB for most limits, 8 MiB for API calls, and 64 MiB for downloads.
    fn default() -> Self {
        const MIB: u64 = 1 << 20;
        const MIB_USIZE: usize = 1 << 20;
        Self {
            request_timeout: std::time::Duration::from_secs(30),
            max_session_body: MIB,
            max_call_body: 8 * MIB,
            max_download_body: 64 * MIB,
            max_upload_response_body: MIB,
            max_sse_frame: MIB_USIZE,
            max_ws_message: MIB_USIZE,
        }
    }
}
impl ClientConfig {
    /// Checks that the timeout and every size limit are non-zero.
    ///
    /// Fields are checked in this order: `max_session_body`, `max_call_body`,
    /// `max_download_body`, `max_upload_response_body`, `request_timeout`,
    /// `max_sse_frame`, `max_ws_message`.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::InvalidArgument`] naming the first field that is zero.
    pub fn validate(&self) -> Result<(), ClientError> {
        let require_nonzero = |field: &str, is_zero: bool| -> Result<(), ClientError> {
            if is_zero {
                Err(ClientError::InvalidArgument(format!(
                    "ClientConfig.{field} must be > 0"
                )))
            } else {
                Ok(())
            }
        };

        require_nonzero("max_session_body", self.max_session_body == 0)?;
        require_nonzero("max_call_body", self.max_call_body == 0)?;
        require_nonzero("max_download_body", self.max_download_body == 0)?;
        require_nonzero(
            "max_upload_response_body",
            self.max_upload_response_body == 0,
        )?;
        if self.request_timeout.is_zero() {
            return Err(ClientError::InvalidArgument(
                "ClientConfig.request_timeout must be > 0; use Duration::from_secs(30) or similar"
                    .into(),
            ));
        }
        require_nonzero("max_sse_frame", self.max_sse_frame == 0)?;
        require_nonzero("max_ws_message", self.max_ws_message == 0)?;
        Ok(())
    }
}
/// HTTP client for a single JMAP server.
///
/// Cloning is cheap for the auth provider (shared through an `Arc`).
#[derive(Clone)]
pub struct JmapClient {
    /// Validated origin (path `/`) against which `.well-known/jmap` is resolved.
    pub(crate) base_url: url::Url,
    /// Limits and timeouts applied to requests.
    pub(crate) config: ClientConfig,
    /// HTTP client produced by the configured transport.
    pub(crate) http: reqwest::Client,
    /// Source of the authentication header attached by `inject_auth`.
    pub(crate) auth: Arc<dyn AuthProvider>,
}
impl std::fmt::Debug for JmapClient {
    /// Prints the base URL and config only; the auth provider and HTTP client are
    /// deliberately omitted and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("JmapClient");
        out.field("base_url", &self.base_url);
        out.field("config", &self.config);
        out.finish_non_exhaustive()
    }
}
impl JmapClient {
    /// Creates a client for the JMAP server at `base_url`.
    ///
    /// `base_url` must be an absolute `http`/`https` URL whose path is `/` and which has no
    /// query or fragment. `config` is validated before the HTTP client is built.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::InvalidArgument`] if `base_url` or `config` is rejected, and
    /// propagates any error from `transport.build_client()`.
    pub fn new(
        transport: impl TransportConfig,
        auth: impl AuthProvider + 'static,
        base_url: &str,
        config: ClientConfig,
    ) -> Result<Self, ClientError> {
        let parsed = parse_base_url(base_url)?;
        config.validate()?;
        let http = transport.build_client()?;
        Ok(Self {
            base_url: parsed,
            auth: Arc::new(auth),
            http,
            config,
        })
    }

    /// Same as [`JmapClient::new`], using [`DefaultTransport`] to build the HTTP client.
    pub fn new_plain(
        auth: impl AuthProvider + 'static,
        base_url: &str,
        config: ClientConfig,
    ) -> Result<Self, ClientError> {
        Self::new(DefaultTransport, auth, base_url, config)
    }

    /// Adds the auth provider's header to `builder` when the provider supplies one.
    pub(crate) fn inject_auth(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        if let Some((name, value)) = self.auth.auth_header() {
            builder.header(name, value)
        } else {
            builder
        }
    }

    /// Maps HTTP 401 and 403 to [`ClientError::AuthFailed`]; every other status passes.
    pub(crate) fn check_auth_status(status: reqwest::StatusCode) -> Result<(), ClientError> {
        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
            Err(ClientError::AuthFailed(status.as_u16()))
        } else {
            Ok(())
        }
    }

    /// Sends `builder` and rejects unsuccessful responses.
    ///
    /// 401/403 become [`ClientError::AuthFailed`]. Other 4xx/5xx statuses are turned into
    /// errors by `error_for_status` and mapped with [`ClientError::from_reqwest`].
    async fn send_checked(
        builder: reqwest::RequestBuilder,
    ) -> Result<reqwest::Response, ClientError> {
        let resp = builder.send().await.map_err(ClientError::from_reqwest)?;
        Self::check_auth_status(resp.status())?;
        resp.error_for_status().map_err(ClientError::from_reqwest)
    }

    /// Reads the complete response body and fails once it exceeds `limit` bytes.
    ///
    /// A declared `Content-Length` above `limit` is rejected before any body is read.
    /// Otherwise the body is consumed chunk by chunk, so a server that omits the length
    /// (for example with chunked encoding) cannot make the client buffer more than `limit`
    /// bytes of body. On overflow, `actual` is the number of bytes received so far, which
    /// is a lower bound on the real body size.
    async fn read_limited_body(
        mut resp: reqwest::Response,
        limit: u64,
    ) -> Result<Vec<u8>, ClientError> {
        if let Some(len) = resp.content_length() {
            if len > limit {
                return Err(ClientError::ResponseTooLarge { actual: len, limit });
            }
        }
        let mut body = Vec::new();
        while let Some(chunk) = resp.chunk().await.map_err(ClientError::from_reqwest)? {
            let received = body.len() as u64 + chunk.len() as u64;
            if received > limit {
                return Err(ClientError::ResponseTooLarge {
                    actual: received,
                    limit,
                });
            }
            body.extend_from_slice(&chunk);
        }
        Ok(body)
    }

    /// Fetches and validates the JMAP Session resource at `<base_url>/.well-known/jmap`.
    ///
    /// The request uses `config.request_timeout`, and the body is limited to
    /// `config.max_session_body` bytes.
    ///
    /// # Errors
    ///
    /// - [`ClientError::InvalidSession`] if the session URL cannot be built, or if a URL
    ///   in the session does not use `http`/`https`.
    /// - [`ClientError::AuthFailed`] on 401/403, or a transport error for other failures.
    /// - [`ClientError::ResponseTooLarge`] if the body exceeds the limit.
    /// - [`ClientError::Parse`] if the body is not a valid session document.
    pub async fn fetch_session(&self) -> Result<Session, ClientError> {
        let url = self.base_url.join(".well-known/jmap").map_err(|e| {
            ClientError::InvalidSession(format!("cannot construct session URL: {e}"))
        })?;
        let req = self.inject_auth(self.http.get(url).timeout(self.config.request_timeout));
        let resp = Self::send_checked(req).await?;
        let bytes = Self::read_limited_body(resp, self.config.max_session_body).await?;
        let session: Session = serde_json::from_slice(&bytes).map_err(ClientError::Parse)?;
        validate_session_url_schemes(&session)?;
        Ok(session)
    }

    /// POSTs `req` as JSON to `api_url` and parses the JMAP response.
    ///
    /// The request uses `config.request_timeout`, and the body is limited to
    /// `config.max_call_body` bytes.
    ///
    /// # Errors
    ///
    /// - [`ClientError::InvalidArgument`] if `api_url` is not `http`/`https`.
    /// - [`ClientError::AuthFailed`] on 401/403, or a transport error for other failures.
    /// - [`ClientError::ResponseTooLarge`] if the body exceeds the limit.
    /// - [`ClientError::Parse`] if the body is not a valid JMAP response.
    pub async fn call(
        &self,
        api_url: &str,
        req: &jmap_types::JmapRequest,
    ) -> Result<jmap_types::JmapResponse, ClientError> {
        require_http_url(api_url)?;
        let builder = self.inject_auth(
            self.http
                .post(api_url)
                .json(req)
                .timeout(self.config.request_timeout),
        );
        let resp = Self::send_checked(builder).await?;
        let bytes = Self::read_limited_body(resp, self.config.max_call_body).await?;
        let jmap_resp: jmap_types::JmapResponse =
            serde_json::from_slice(&bytes).map_err(ClientError::Parse)?;
        Ok(jmap_resp)
    }

    /// Opens a Server-Sent Events stream at `event_source_url` and yields parsed frames.
    ///
    /// No request timeout is applied, so the stream may stay open indefinitely.
    /// `last_event_id`, when given, is sent as the `Last-Event-ID` header.
    ///
    /// The returned stream yields one item per blank-line-terminated event. It ends when
    /// the server closes the body; an unterminated trailing event is discarded. It also
    /// ends after the first error it yields.
    ///
    /// # Errors
    ///
    /// Up front: [`ClientError::InvalidArgument`] for a non-`http`/`https` URL,
    /// [`ClientError::AuthFailed`] on 401/403, a transport error for other failures, and
    /// [`ClientError::UnexpectedResponse`] when the Content-Type is not
    /// `text/event-stream`. In the stream: [`ClientError::SseFrameTooLarge`] when a
    /// single event, or pending data without a terminator, exceeds `config.max_sse_frame`.
    pub async fn subscribe_events(
        &self,
        event_source_url: &str,
        last_event_id: Option<&str>,
    ) -> Result<futures::stream::BoxStream<'static, Result<SseFrame, ClientError>>, ClientError>
    {
        require_http_url(event_source_url)?;
        let mut req = self
            .http
            .get(event_source_url)
            .header("Accept", "text/event-stream");
        if let Some(id) = last_event_id {
            req = req.header("Last-Event-ID", id);
        }
        let resp = Self::send_checked(self.inject_auth(req)).await?;
        Self::require_event_stream_content_type(&resp)?;

        let sse_frame_limit = self.config.max_sse_frame;
        let initial = SseStreamState {
            stream: resp.bytes_stream(),
            raw_buf: Vec::new(),
            buf: String::new(),
            scan_from: 0,
        };
        Ok(futures::stream::unfold(
            Some(initial),
            move |state| async move {
                // `None` state means a previous step yielded a terminal error.
                let SseStreamState {
                    mut stream,
                    mut raw_buf,
                    mut buf,
                    mut scan_from,
                } = state?;
                loop {
                    if let Some((pos, delim_len)) = Self::find_frame_end(&buf, scan_from) {
                        // Apply the limit per event, not per network chunk: one chunk can
                        // legitimately carry many small events.
                        if pos > sse_frame_limit {
                            return Some((
                                Err(ClientError::SseFrameTooLarge {
                                    limit: sse_frame_limit,
                                }),
                                None,
                            ));
                        }
                        let frame = Self::normalize_line_endings(&buf[..pos]);
                        buf.drain(..pos + delim_len);
                        let next = SseStreamState {
                            stream,
                            raw_buf,
                            buf,
                            scan_from: 0,
                        };
                        return Some((Ok(parse_sse_block(&frame)), Some(next)));
                    }
                    // No terminator yet: everything in `buf` belongs to one pending event.
                    if buf.len() > sse_frame_limit {
                        return Some((
                            Err(ClientError::SseFrameTooLarge {
                                limit: sse_frame_limit,
                            }),
                            None,
                        ));
                    }
                    match stream.next().await {
                        None => return None,
                        Some(Err(e)) => {
                            return Some((Err(ClientError::from_reqwest(e)), None));
                        }
                        Some(Ok(bytes)) => {
                            raw_buf.extend_from_slice(&bytes);
                            let old_len = buf.len();
                            decode_utf8_chunk(&mut raw_buf, &mut buf);
                            // Bytes the decoder has not consumed also count against the
                            // limit, so they cannot pile up without bound.
                            if raw_buf.len() > sse_frame_limit {
                                return Some((
                                    Err(ClientError::SseFrameTooLarge {
                                        limit: sse_frame_limit,
                                    }),
                                    None,
                                ));
                            }
                            // A terminator is at most 4 bytes long, so one that straddles the
                            // old/new boundary starts at most 3 bytes before the new text.
                            // Back up to a char boundary so slicing stays valid.
                            scan_from = old_len.saturating_sub(3);
                            while scan_from > 0 && !buf.is_char_boundary(scan_from) {
                                scan_from -= 1;
                            }
                        }
                    }
                }
            },
        )
        .boxed())
    }

    /// Rejects responses whose Content-Type essence is not `text/event-stream`.
    ///
    /// The comparison is case-insensitive and ignores parameters such as `charset`.
    fn require_event_stream_content_type(resp: &reqwest::Response) -> Result<(), ClientError> {
        let ct = resp
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
            .to_ascii_lowercase();
        let essence = ct
            .split(|c: char| c == ';' || c.is_whitespace())
            .next()
            .unwrap_or("");
        if essence != "text/event-stream" {
            return Err(ClientError::UnexpectedResponse(format!(
                "subscribe_events: expected Content-Type text/event-stream, got: {ct:?}"
            )));
        }
        Ok(())
    }

    /// Finds the earliest event terminator (a blank line) in `buf` at or after `from`.
    ///
    /// Returns `(offset, length)` of the terminator. The recognised terminators are
    /// `\r\n\r\n`, `\n\n`, `\r\r` and `\n\r\n`. `from` must lie on a char boundary.
    fn find_frame_end(buf: &str, from: usize) -> Option<(usize, usize)> {
        const TERMINATORS: [&str; 4] = ["\r\n\r\n", "\n\n", "\r\r", "\n\r\n"];
        let tail = &buf[from..];
        TERMINATORS
            .iter()
            .copied()
            .filter_map(|t| tail.find(t).map(|p| (from + p, t.len())))
            .min_by_key(|&(pos, _)| pos)
    }

    /// Converts CRLF and lone CR line endings in `block` to LF.
    fn normalize_line_endings(block: &str) -> String {
        if block.contains('\r') {
            block.replace("\r\n", "\n").replace('\r', "\n")
        } else {
            block.to_owned()
        }
    }

    /// Opens a WebSocket session at `ws_url` with `config.max_ws_message` as the message
    /// size limit.
    ///
    /// The client's own auth provider is not consulted; callers pass `auth_header`
    /// explicitly.
    pub async fn connect_ws_session(
        &self,
        ws_url: &str,
        auth_header: Option<(&str, &str)>,
    ) -> Result<crate::ws::WsSession, ClientError> {
        crate::ws::connect_ws_with_limit(ws_url, auth_header, self.config.max_ws_message).await
    }
}
/// Finds the response to method call `call_id` and deserializes its arguments into `T`.
///
/// Error responses take precedence. If any invocation answering `call_id` is named
/// `"error"`, the first such one is reported, even if a successful response with the same
/// id appears earlier. Otherwise the first invocation carrying `call_id` is deserialized.
///
/// # Errors
///
/// - [`ClientError::MethodError`] for a JMAP method error; `error_type` falls back to
///   `"serverError"` when `type` is missing or not a string.
/// - [`ClientError::MethodNotFound`] when no invocation carries `call_id`.
/// - [`ClientError::Parse`] when the arguments do not deserialize into `T`.
pub fn extract_response<T: serde::de::DeserializeOwned>(
    resp: &jmap_types::JmapResponse,
    call_id: &str,
) -> Result<T, ClientError> {
    let mut answering = resp
        .method_responses
        .iter()
        .filter(|inv| inv.2 == call_id);

    if let Some(failure) = answering.clone().find(|inv| inv.0 == "error") {
        let args = &failure.1;
        let error_type = match args.get("type").and_then(|v| v.as_str()) {
            Some(kind) => kind.to_owned(),
            None => "serverError".to_owned(),
        };
        let description = args
            .get("description")
            .and_then(|v| v.as_str())
            .map(str::to_owned);
        return Err(ClientError::MethodError {
            error_type,
            description,
        });
    }

    // No error for this id, so the first match is the first successful response.
    let success = answering
        .next()
        .ok_or_else(|| ClientError::MethodNotFound(call_id.to_owned()))?;
    <T as serde::Deserialize>::deserialize(&success.1).map_err(ClientError::Parse)
}
/// Moves as much of `raw` as possible into `buf` as UTF-8 text.
///
/// Decoding continues past invalid input. Each invalid sequence (its length as reported by
/// `Utf8Error::error_len`) is dropped, not replaced with U+FFFD, and decoding resumes
/// with the bytes after it. Valid text that follows a bad byte is therefore never held back
/// until the next call. On return, `raw` holds at most a trailing, possibly incomplete
/// multi-byte sequence, which bytes appended later can complete.
fn decode_utf8_chunk(raw: &mut Vec<u8>, buf: &mut String) {
    let mut consumed = 0;
    while consumed < raw.len() {
        match std::str::from_utf8(&raw[consumed..]) {
            Ok(text) => {
                buf.push_str(text);
                consumed = raw.len();
            }
            Err(err) => {
                let valid_up_to = err.valid_up_to();
                let valid = &raw[consumed..consumed + valid_up_to];
                buf.push_str(
                    std::str::from_utf8(valid).expect("valid_up_to is a valid UTF-8 boundary"),
                );
                match err.error_len() {
                    // Definitely invalid bytes: skip them and keep decoding the rest.
                    Some(bad_len) => consumed += valid_up_to + bad_len,
                    // Input ends mid-sequence: keep the tail for the next call.
                    None => {
                        consumed += valid_up_to;
                        break;
                    }
                }
            }
        }
    }
    raw.drain(..consumed);
}
/// Returns the text before the first `://` in `url`, or `None` if there is no `://`.
fn url_scheme(url: &str) -> Option<&str> {
    let separator_at = url.find("://")?;
    Some(&url[..separator_at])
}
/// Reports whether `url` starts with an `http://` or `https://` scheme (case-insensitive).
fn is_http_or_https(url: &str) -> bool {
    match url_scheme(url) {
        Some(scheme) => ["http", "https"]
            .iter()
            .any(|allowed| scheme.eq_ignore_ascii_case(allowed)),
        None => false,
    }
}
/// Parses and validates the client's base URL.
///
/// The URL must be non-empty, parse as an absolute URL, use the `http` or `https` scheme,
/// and have no path other than `/`, no query string and no fragment.
///
/// # Errors
///
/// Returns [`ClientError::InvalidArgument`] describing the first rule that is violated.
fn parse_base_url(base_url: &str) -> Result<url::Url, ClientError> {
    if base_url.is_empty() {
        return Err(ClientError::InvalidArgument(
            "base_url may not be empty".into(),
        ));
    }
    let url = url::Url::parse(base_url).map_err(|err| {
        ClientError::InvalidArgument(format!("base_url is not a valid URL: {err}"))
    })?;

    let scheme = url.scheme();
    let path = url.path();
    let problem: Option<String> = if scheme != "http" && scheme != "https" {
        Some(format!(
            "base_url scheme must be http or https, got: {scheme:?}"
        ))
    } else if path != "/" {
        Some(format!(
            "base_url must not have a path component, got: {path:?}"
        ))
    } else if url.query().is_some() {
        Some("base_url must not have a query string".to_owned())
    } else if url.fragment().is_some() {
        Some("base_url must not have a fragment".to_owned())
    } else {
        None
    };

    match problem {
        Some(message) => Err(ClientError::InvalidArgument(message)),
        None => Ok(url),
    }
}
/// Ensures `url` uses the `http` or `https` scheme.
///
/// # Errors
///
/// Returns [`ClientError::InvalidArgument`] quoting the rejected URL.
pub(crate) fn require_http_url(url: &str) -> Result<(), ClientError> {
    if is_http_or_https(url) {
        Ok(())
    } else {
        Err(ClientError::InvalidArgument(format!(
            "URL must have http or https scheme, got: {url:?}"
        )))
    }
}
/// Checks that the session's API, upload, download and event-source URLs all use
/// `http` or `https`.
///
/// # Errors
///
/// Returns [`ClientError::InvalidSession`] naming the first offending URL, checked in the
/// order api, upload, download, event source.
fn validate_session_url_schemes(session: &Session) -> Result<(), ClientError> {
    let session_urls = [
        &session.api_url,
        &session.upload_url,
        &session.download_url,
        &session.event_source_url,
    ];
    match session_urls
        .iter()
        .find(|candidate| !is_http_or_https(candidate))
    {
        Some(url) => Err(ClientError::InvalidSession(format!(
            "session URL has non-http/https scheme: {url:?}"
        ))),
        None => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::decode_utf8_chunk;

    /// Runs one decode step on `input` into an empty buffer and returns
    /// `(decoded_text, bytes_left_in_raw)`.
    fn decode_once(input: &[u8]) -> (String, Vec<u8>) {
        let mut raw = input.to_vec();
        let mut text = String::new();
        decode_utf8_chunk(&mut raw, &mut text);
        (text, raw)
    }

    #[test]
    fn decode_utf8_chunk_all_ascii() {
        let (text, rest) = decode_once(b"hello");
        assert_eq!(text, "hello");
        assert!(rest.is_empty());
    }

    #[test]
    fn decode_utf8_chunk_complete_multibyte() {
        let (text, rest) = decode_once("café".as_bytes());
        assert_eq!(text, "café");
        assert!(rest.is_empty());
    }

    #[test]
    fn decode_utf8_chunk_incomplete_head_retained() {
        let (text, rest) = decode_once(&[0xC3u8]);
        assert_eq!(text, "", "no complete codepoints to push");
        assert_eq!(rest, vec![0xC3u8], "incomplete head must stay in raw");
    }

    #[test]
    fn decode_utf8_chunk_prefix_then_incomplete_head() {
        let (text, rest) = decode_once(&[b'a', b'b', 0xC3u8]);
        assert_eq!(text, "ab");
        assert_eq!(rest, vec![0xC3u8], "incomplete head must stay in raw");
    }

    #[test]
    fn decode_utf8_chunk_split_sequence_completed() {
        // Two calls sharing the same buffers, as in the SSE stream loop.
        let mut pending = vec![0xC3u8];
        let mut text = String::new();
        decode_utf8_chunk(&mut pending, &mut text);
        assert_eq!(pending, vec![0xC3u8], "incomplete head retained after chunk 1");

        pending.push(0xA9u8);
        decode_utf8_chunk(&mut pending, &mut text);
        assert_eq!(text, "é", "character fully decoded after chunk 2");
        assert!(pending.is_empty());
    }

    #[test]
    fn decode_utf8_chunk_invalid_byte_drained() {
        let (text, rest) = decode_once(&[0xFFu8]);
        assert_eq!(text, "");
        assert!(rest.is_empty(), "definitively invalid byte must be drained");
    }

    #[test]
    fn decode_utf8_chunk_prefix_then_invalid_drained() {
        let (text, rest) = decode_once(&[b'a', b'b', 0xFFu8]);
        assert_eq!(text, "ab");
        assert!(rest.is_empty(), "prefix and invalid byte must be drained");
    }
}