use rama_core::bytes::{BufMut, Bytes, BytesMut};
use rama_core::error::{BoxError, ErrorContext as _, ErrorExt as _};
use rama_core::futures::{StreamExt, TryStreamExt, stream};
use rama_core::stream::io::ReaderStream;
use rama_core::telemetry::tracing;
use rama_http_types::{HeaderMap, HeaderValue, header, mime};
use rama_utils::collections::smallvec::SmallVec;
use rama_utils::macros::generate_set_and_with;
use rama_utils::str::smol_str::{SmolStr, format_smolstr};
use rand::RngExt as _;
use std::borrow::Cow;
use std::path::Path;
use std::pin::Pin;
use tokio::io::AsyncReadExt as _;
// Number of parts a `Form` keeps inline in its `SmallVec` before spilling to the heap.
const PARTS_INLINE_CAP: usize = 4;
// Line terminator used for every framing line.
const CRLF: &[u8] = b"\r\n";
// Marker placed before each boundary and after the closing boundary.
const DASH_DASH: &[u8] = b"--";
// Start of the per-part disposition header, up to and including the opening quote of `name`.
const FIELD_DISPOSITION_PREFIX: &[u8] = b"Content-Disposition: form-data; name=\"";
// Start of the optional `filename` parameter, up to and including its opening quote.
const FILENAME_PREFIX: &[u8] = b"; filename=\"";
// Start of the optional per-part content type header line.
const CONTENT_TYPE_PREFIX: &[u8] = b"Content-Type: ";
// Closing quote for the quoted `name` / `filename` values.
const QUOTE: &[u8] = b"\"";
// Separator between header name and header value for extra per-part headers.
const HEADER_KV_SEP: &[u8] = b": ";
// Boxed, pinned, `Send` stream of byte chunks; each item may fail with a `BoxError`.
type ChunkStream = Pin<Box<dyn rama_core::futures::Stream<Item = Result<Bytes, BoxError>> + Send>>;
/// Builder for a `multipart/form-data` body.
///
/// Parts are kept in insertion order and rendered by [`Form::into_stream`].
#[derive(Debug)]
#[must_use]
pub struct Form {
// Boundary token written between parts and advertised by `content_type()`.
boundary: SmolStr,
// Named parts in the order they were added.
parts: SmallVec<[NamedPart; PARTS_INLINE_CAP]>,
}
/// A [`Part`] together with the form field name it is sent under.
#[derive(Debug)]
struct NamedPart {
// Value of the `name` parameter in the part's disposition header.
name: Cow<'static, str>,
// Body and metadata of the part.
part: Part,
}
impl Default for Form {
fn default() -> Self {
Self::new()
}
}
impl Form {
pub fn new() -> Self {
Self {
boundary: gen_boundary(),
parts: SmallVec::new(),
}
}
#[must_use]
pub fn boundary(&self) -> &str {
&self.boundary
}
#[must_use]
#[expect(
clippy::unreachable,
reason = "boundary is constructed from validated bytes; HeaderValue::try_from is infallible by construction here, the Err arm exists only to satisfy unwrap_used/expect_used"
)]
pub fn content_type(&self) -> HeaderValue {
let value = format!("multipart/form-data; boundary={}", self.boundary);
match HeaderValue::try_from(value) {
Ok(v) => v,
Err(_) => unreachable!("multipart boundary always converts to a HeaderValue"),
}
}
pub fn text<N, V>(self, name: N, value: V) -> Self
where
N: Into<Cow<'static, str>>,
V: Into<Cow<'static, str>>,
{
self.part(name, Part::text(value))
}
pub fn bytes<N, B>(self, name: N, value: B) -> Self
where
N: Into<Cow<'static, str>>,
B: Into<Bytes>,
{
self.part(name, Part::bytes(value))
}
pub async fn file<N, P>(self, name: N, path: P) -> std::io::Result<Self>
where
N: Into<Cow<'static, str>>,
P: AsRef<Path>,
{
let part = Part::file(path).await?;
Ok(self.part(name, part))
}
pub async fn with_field_spec(self, spec: &str) -> Result<Self, FieldSpecError> {
let parsed = FieldSpec::parse(spec)?;
let name = parsed.name.to_owned();
let part = parsed.into_part().await?;
Ok(self.part(name, part))
}
pub fn part<N>(mut self, name: N, part: Part) -> Self
where
N: Into<Cow<'static, str>>,
{
self.parts.push(NamedPart {
name: name.into(),
part,
});
self
}
#[must_use]
pub fn content_length(&self) -> Option<u64> {
let mut total: u64 = 0;
for np in &self.parts {
let part_size = np.part.content_size?;
let header_len = part_headers_len(&self.boundary, &np.name, &np.part) as u64;
total = total.checked_add(header_len)?;
total = total.checked_add(part_size)?;
total = total.checked_add(CRLF.len() as u64)?;
}
let trailer_len =
(DASH_DASH.len() + self.boundary.len() + DASH_DASH.len() + CRLF.len()) as u64;
total = total.checked_add(trailer_len)?;
Some(total)
}
pub fn into_stream(
self,
) -> impl rama_core::futures::Stream<Item = Result<Bytes, BoxError>> + Send {
let boundary = self.boundary;
let n_parts = self.parts.len();
let trailer = {
let cap = if n_parts == 0 { 0 } else { CRLF.len() }
+ DASH_DASH.len()
+ boundary.len()
+ DASH_DASH.len()
+ CRLF.len();
let mut buf = BytesMut::with_capacity(cap);
if n_parts > 0 {
buf.put_slice(CRLF);
}
buf.put_slice(DASH_DASH);
buf.put_slice(boundary.as_bytes());
buf.put_slice(DASH_DASH);
buf.put_slice(CRLF);
buf.freeze()
};
let mut chunks: Vec<ChunkStream> = Vec::with_capacity(n_parts * 2 + 1);
for (i, np) in self.parts.into_iter().enumerate() {
let framing = render_framing(&boundary, &np.name, &np.part, i > 0);
chunks.push(Box::pin(stream::iter([Ok::<Bytes, BoxError>(framing)])));
chunks.push(match np.part.body {
PartBody::Bytes(b) => Box::pin(stream::iter([Ok::<Bytes, BoxError>(b)])),
PartBody::Stream(s) => s,
});
}
chunks.push(Box::pin(stream::iter([Ok::<Bytes, BoxError>(trailer)])));
stream::iter(chunks).flatten()
}
pub fn into_body(self) -> crate::Body {
crate::Body::from_stream(self.into_stream())
}
}
/// A single part of a multipart form: a body plus optional metadata.
#[must_use]
pub struct Part {
// Payload, either in memory or streamed.
body: PartBody,
// Size of the payload in bytes when known; `None` makes the form length unknown.
content_size: Option<u64>,
// Value of the `filename` parameter in the disposition header, if any.
file_name: Option<Cow<'static, str>>,
// Value of the part's `Content-Type` header line, if any.
mime: Option<mime::Mime>,
// Extra per-part headers; disposition and content-type entries are skipped when rendering.
headers: HeaderMap,
}
/// Manual `Debug`: the body is summarised (kind and, for in-memory bodies,
/// size) instead of being printed, and the MIME type is shown as its essence.
impl std::fmt::Debug for Part {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let body_kind = match &self.body {
            PartBody::Stream(_) => String::from("stream"),
            PartBody::Bytes(bytes) => format!("bytes ({} B)", bytes.len()),
        };
        let mime_essence = self.mime.as_ref().map(mime::Mime::essence_str);

        let mut out = f.debug_struct("Part");
        out.field("body_kind", &body_kind);
        out.field("content_size", &self.content_size);
        out.field("file_name", &self.file_name);
        out.field("mime", &mime_essence);
        out.field("headers", &self.headers);
        out.finish()
    }
}
/// Payload of a [`Part`].
enum PartBody {
// Fully buffered payload.
Bytes(Bytes),
// Payload produced lazily as a stream of chunks.
Stream(ChunkStream),
}
impl Part {
    /// Creates a part from text.
    ///
    /// A borrowed `&'static str` is wrapped without copying; an owned
    /// `String` hands its buffer over to the resulting `Bytes`.
    pub fn text<V: Into<Cow<'static, str>>>(value: V) -> Self {
        let bytes = match value.into() {
            Cow::Borrowed(text) => Bytes::from_static(text.as_bytes()),
            Cow::Owned(text) => Bytes::from(text),
        };
        Self::bytes(bytes)
    }

    /// Creates a part from in-memory bytes; its content size is the byte length.
    pub fn bytes<B: Into<Bytes>>(value: B) -> Self {
        let bytes: Bytes = value.into();
        let len = bytes.len() as u64;
        Self {
            body: PartBody::Bytes(bytes),
            content_size: Some(len),
            file_name: None,
            mime: None,
            headers: HeaderMap::new(),
        }
    }

    /// Creates a part whose body is produced by `stream`.
    ///
    /// The content size is unknown (`None`) unless it is set afterwards
    /// through the content-size setter.
    pub fn stream<S, O, E>(stream: S) -> Self
    where
        S: rama_core::futures::Stream<Item = Result<O, E>> + Send + 'static,
        O: Into<Bytes> + 'static,
        E: Into<BoxError> + 'static,
    {
        let mapped = stream.map_ok(Into::into).map_err(Into::into);
        Self {
            body: PartBody::Stream(Box::pin(mapped)),
            content_size: None,
            file_name: None,
            mime: None,
            headers: HeaderMap::new(),
        }
    }

    /// Creates a part that streams the file at `path`.
    ///
    /// The file name is taken from the last path component (lossily converted
    /// to UTF-8). The MIME type is guessed from the extension and falls back
    /// to `application/octet-stream`.
    ///
    /// The content size is only reported for regular files. For pipes,
    /// devices and other special files the metadata length does not describe
    /// how many bytes will be read, so the size is left unknown.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from opening the file or reading its metadata.
    pub async fn file<P: AsRef<Path>>(path: P) -> std::io::Result<Self> {
        let path = path.as_ref();
        let file_name: Option<Cow<'static, str>> = path
            .file_name()
            .map(|name| Cow::Owned(name.to_string_lossy().into_owned()));
        let mime = path
            .extension()
            .and_then(std::ffi::OsStr::to_str)
            .and_then(|ext| mime_guess::from_ext(ext).first())
            .unwrap_or(mime::APPLICATION_OCTET_STREAM);
        let file = rama_utils::fs::safe_open(path).await?;
        let metadata = file.metadata().await?;
        let len = metadata.len();
        // A FIFO or character device typically reports a length of 0 while
        // still yielding data; advertising that would produce a wrong
        // Content-Length for the whole form.
        let content_size = metadata.is_file().then_some(len);
        tracing::debug!(
            path = %path.display(),
            size = len,
            mime = %mime,
            "multipart::Part::file: opened file for streaming",
        );
        let stream = ReaderStream::new(file);
        let mapped = stream.map_ok(Bytes::from).map_err(BoxError::from);
        Ok(Self {
            body: PartBody::Stream(Box::pin(mapped)),
            content_size,
            file_name,
            mime: Some(mime),
            headers: HeaderMap::new(),
        })
    }

    generate_set_and_with! {
        /// Sets the `filename` parameter of the part's disposition header.
        pub fn file_name(mut self, file_name: impl Into<Cow<'static, str>>) -> Self {
            self.file_name = Some(file_name.into());
            self
        }
    }

    generate_set_and_with! {
        /// Sets (or clears, with `None`) the part's content type.
        pub fn mime(mut self, mime: Option<mime::Mime>) -> Self {
            self.mime = mime;
            self
        }
    }

    generate_set_and_with! {
        /// Parses `mime_str` and sets it as the part's content type.
        ///
        /// Fails with the parse error when `mime_str` is not a valid MIME type.
        pub fn mime_str(mut self, mime_str: &str) -> Result<Self, mime::FromStrError> {
            self.mime = Some(mime_str.parse()?);
            Ok(self)
        }
    }

    generate_set_and_with! {
        /// Overrides the advertised body size used for the form's content length.
        ///
        /// The caller is responsible for the value matching the bytes the
        /// body actually yields.
        pub fn content_size(mut self, size: Option<u64>) -> Self {
            self.content_size = size;
            self
        }
    }

    generate_set_and_with! {
        /// Replaces the extra headers rendered for this part.
        pub fn headers(mut self, headers: HeaderMap) -> Self {
            self.headers = headers;
            self
        }
    }
}
/// Renders the framing that precedes a part's body: the delimiter line, the
/// disposition header, the optional content-type header, any extra headers
/// and the blank line that ends the header section.
///
/// When `with_leading_crlf` is set, a CRLF is emitted first to terminate the
/// previous part's body.
fn render_framing(boundary: &str, name: &str, part: &Part, with_leading_crlf: bool) -> Bytes {
    let leading_len = if with_leading_crlf { CRLF.len() } else { 0 };
    let mut out = BytesMut::with_capacity(leading_len + part_headers_len(boundary, name, part));

    if with_leading_crlf {
        out.put_slice(CRLF);
    }

    // Delimiter line followed by the start of the disposition header.
    for piece in [
        DASH_DASH,
        boundary.as_bytes(),
        CRLF,
        FIELD_DISPOSITION_PREFIX,
    ] {
        out.put_slice(piece);
    }
    write_quoted(&mut out, name);
    out.put_slice(QUOTE);

    if let Some(file_name) = &part.file_name {
        out.put_slice(FILENAME_PREFIX);
        write_quoted(&mut out, file_name);
        out.put_slice(QUOTE);
    }
    out.put_slice(CRLF);

    if let Some(mime) = &part.mime {
        let mime_str: &str = mime.as_ref();
        out.put_slice(CONTENT_TYPE_PREFIX);
        out.put_slice(mime_str.as_bytes());
        out.put_slice(CRLF);
    }

    for (header_name, header_value) in &part.headers {
        // These two headers are rendered from the part's own fields above.
        let reserved =
            header_name == header::CONTENT_DISPOSITION || header_name == header::CONTENT_TYPE;
        if !reserved {
            out.put_slice(header_name.as_str().as_bytes());
            out.put_slice(HEADER_KV_SEP);
            out.put_slice(header_value.as_bytes());
            out.put_slice(CRLF);
        }
    }

    // Blank line separating headers from the body.
    out.put_slice(CRLF);
    out.freeze()
}
/// Computes the number of bytes `render_framing` writes for a part, not
/// counting the optional leading CRLF.
fn part_headers_len(boundary: &str, name: &str, part: &Part) -> usize {
    let delimiter_line = DASH_DASH.len() + boundary.len() + CRLF.len();

    let file_name_param = part.file_name.as_deref().map_or(0, |file_name| {
        FILENAME_PREFIX.len() + quoted_len(file_name) + QUOTE.len()
    });
    let disposition_line = FIELD_DISPOSITION_PREFIX.len()
        + quoted_len(name)
        + QUOTE.len()
        + file_name_param
        + CRLF.len();

    let content_type_line = part.mime.as_ref().map_or(0, |mime| {
        let mime_str: &str = mime.as_ref();
        CONTENT_TYPE_PREFIX.len() + mime_str.len() + CRLF.len()
    });

    // Disposition and content-type entries in the extra headers are never
    // rendered, so they must not be counted either.
    let extra_header_lines = part
        .headers
        .iter()
        .filter(|(key, _)| {
            !(*key == header::CONTENT_DISPOSITION || *key == header::CONTENT_TYPE)
        })
        .map(|(key, value)| {
            key.as_str().len() + HEADER_KV_SEP.len() + value.as_bytes().len() + CRLF.len()
        })
        .sum::<usize>();

    let blank_line = CRLF.len();

    delimiter_line + disposition_line + content_type_line + extra_header_lines + blank_line
}
/// Returns how many bytes `s` occupies once quoted: every byte counts once,
/// and each `"` or `\` counts one extra for its escaping backslash.
fn quoted_len(s: &str) -> usize {
    let escaped = s
        .bytes()
        .filter(|&byte| matches!(byte, b'"' | b'\\'))
        .count();
    s.len() + escaped
}
/// Appends `s` to `buf` in quoted form: `"` and `\` are preceded by a
/// backslash, and CR / LF bytes are replaced with a single space so a value
/// cannot break out of its header line.
fn write_quoted(buf: &mut BytesMut, s: &str) {
    for &byte in s.as_bytes() {
        if matches!(byte, b'\r' | b'\n') {
            buf.put_u8(b' ');
            continue;
        }
        if matches!(byte, b'"' | b'\\') {
            buf.put_u8(b'\\');
        }
        buf.put_u8(byte);
    }
}
/// A parsed `name=value[;type=...][;filename=...]` field specification,
/// borrowing from the string it was parsed from.
#[derive(Debug, Clone)]
pub struct FieldSpec<'a> {
// Field name: the text before the first `=`.
pub name: &'a str,
// Where the field's value comes from.
pub source: FieldSpecSource<'a>,
// Value of the `type` modifier, if given.
pub content_type: Option<&'a str>,
// Value of the `filename` modifier, if given.
pub filename: Option<&'a str>,
}
/// Origin of a field spec's value.
#[derive(Debug, Clone)]
pub enum FieldSpecSource<'a> {
// The value is the literal text itself.
Text(&'a str),
// `@path`: the value is streamed from the file at `path`; `-` means stdin.
File(&'a str),
// `<path`: the file at `path` is read as text and sent as a text value; `-` means stdin.
FileText(&'a str),
}
/// Error produced while parsing a field spec or turning it into a part.
#[derive(Debug)]
pub enum FieldSpecError {
// The spec contains no `=` between name and value.
MissingSeparator,
// The text before the first `=` is empty.
EmptyName,
// A modifier is malformed or unknown; also carries I/O and MIME failures
// raised by `FieldSpec::into_part`.
InvalidModifier(BoxError),
}
/// Human-readable messages for each [`FieldSpecError`] variant.
impl std::fmt::Display for FieldSpecError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidModifier(cause) => write!(f, "invalid field spec: {cause}"),
            Self::MissingSeparator => f.write_str("field spec is missing `=` separator"),
            Self::EmptyName => f.write_str("field spec has empty name"),
        }
    }
}
/// Exposes the wrapped error of `InvalidModifier` as the error source; the
/// other variants have no underlying cause.
impl std::error::Error for FieldSpecError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::MissingSeparator | Self::EmptyName => None,
            Self::InvalidModifier(cause) => {
                let inner: &(dyn std::error::Error + 'static) = &**cause;
                Some(inner)
            }
        }
    }
}
/// Minimal message-only error used to build `FieldSpecError::InvalidModifier` values.
#[derive(Debug)]
struct InlineErr(SmolStr);
impl std::fmt::Display for InlineErr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.0)
}
}
// Marker impl so `InlineErr` can be boxed into a `BoxError`; it has no source.
impl std::error::Error for InlineErr {}
impl<'a> FieldSpec<'a> {
pub fn parse(spec: &'a str) -> Result<Self, FieldSpecError> {
let (name, rest) = spec
.split_once('=')
.ok_or(FieldSpecError::MissingSeparator)?;
if name.is_empty() {
return Err(FieldSpecError::EmptyName);
}
let mut content_type: Option<&str> = None;
let mut filename: Option<&str> = None;
let value_part: &str;
if let Some((value, modifiers)) = split_modifiers(rest) {
value_part = value;
for modifier in modifiers.split(';') {
let modifier = modifier.trim();
if modifier.is_empty() {
continue;
}
let (key, val) = modifier.split_once('=').ok_or_else(|| {
FieldSpecError::InvalidModifier(
InlineErr(format_smolstr!("missing `=` in modifier `{modifier}`")).into(),
)
})?;
match key.trim() {
"type" => content_type = Some(val),
"filename" => filename = Some(val),
other => {
return Err(FieldSpecError::InvalidModifier(
InlineErr(format_smolstr!("unknown modifier key `{other}`")).into(),
));
}
}
}
} else {
value_part = rest;
}
let source = if let Some(path) = value_part.strip_prefix('@') {
FieldSpecSource::File(path)
} else if let Some(path) = value_part.strip_prefix('<') {
FieldSpecSource::FileText(path)
} else {
FieldSpecSource::Text(value_part)
};
Ok(Self {
name,
source,
content_type,
filename,
})
}
pub async fn into_part(self) -> Result<Part, FieldSpecError> {
let mut part = match self.source {
FieldSpecSource::Text(s) => Part::text(s.to_owned()),
FieldSpecSource::File("-") => Part::stream(read_stdin_stream()),
FieldSpecSource::File(path) => Part::file(path)
.await
.with_context(|| format_smolstr!("multipart field spec: open file `{path}`"))
.map_err(|e| FieldSpecError::InvalidModifier(e.into_box_error()))?,
FieldSpecSource::FileText("-") => {
let s = read_stdin_to_string()
.await
.context("multipart field spec: read stdin as text")
.map_err(|e| FieldSpecError::InvalidModifier(e.into_box_error()))?;
Part::text(s)
}
FieldSpecSource::FileText(path) => {
let s = tokio::fs::read_to_string(path)
.await
.with_context(|| format_smolstr!("multipart field spec: read file `{path}`"))
.map_err(|e| FieldSpecError::InvalidModifier(e.into_box_error()))?;
Part::text(s)
}
};
if let Some(ct) = self.content_type {
part.try_set_mime_str(ct)
.with_context(|| format_smolstr!("invalid `;type=` mime in field spec: {ct}"))
.map_err(|e| FieldSpecError::InvalidModifier(e.into_box_error()))?;
}
if let Some(fname) = self.filename {
part.set_file_name(fname.to_owned());
}
Ok(part)
}
}
/// Splits `input` at its first `;` into the value and the raw modifier list.
///
/// Returns `None` when `input` contains no `;`.
fn split_modifiers(input: &str) -> Option<(&str, &str)> {
    let semicolon = input.find(';')?;
    let value = &input[..semicolon];
    // `;` is a single byte, so the modifiers start right after it.
    let modifiers = &input[semicolon + 1..];
    Some((value, modifiers))
}
/// Reads all of stdin into a `String`.
///
/// # Errors
///
/// Returns the read error (including invalid UTF-8) with added context.
async fn read_stdin_to_string() -> Result<String, BoxError> {
    let mut stdin = tokio::io::stdin();
    let mut text = String::new();
    stdin
        .read_to_string(&mut text)
        .await
        .context("read multipart field value from stdin")?;
    Ok(text)
}
/// Returns stdin as a stream of byte chunks, with read errors boxed.
fn read_stdin_stream() -> impl rama_core::futures::Stream<Item = Result<Bytes, BoxError>> + Send {
    let stdin = tokio::io::stdin();
    ReaderStream::new(stdin).map(|chunk| chunk.map(Bytes::from).map_err(BoxError::from))
}
/// Generates a random boundary: four random `u64` values, each rendered as
/// 16 zero-padded lowercase hex digits, joined with `-`.
fn gen_boundary() -> SmolStr {
    let mut rng = rand::rng();
    let a: u64 = rng.random();
    let b: u64 = rng.random();
    let c: u64 = rng.random();
    let d: u64 = rng.random();
    format_smolstr!("{a:016x}-{b:016x}-{c:016x}-{d:016x}")
}
// Unit tests for form rendering, content-length accounting and field-spec parsing.
#[cfg(test)]
mod test {
use super::*;
use rama_core::futures::TryStreamExt;
// Renders `form` fully and returns its content type, advertised length and body bytes.
async fn collect(form: Form) -> (HeaderValue, Option<u64>, Vec<u8>) {
let ct = form.content_type();
let len = form.content_length();
let bytes: Vec<u8> = form
.into_stream()
.map_ok(|chunk| chunk.to_vec())
.try_collect::<Vec<Vec<u8>>>()
.await
.unwrap()
.into_iter()
.flatten()
.collect();
(ct, len, bytes)
}
// Two text fields: the advertised length must equal the rendered length.
#[tokio::test]
async fn test_form_text_only() {
let form = Form::new().text("name", "glen").text("language", "rust");
let boundary = form.boundary().to_owned();
let (ct, len, bytes) = collect(form).await;
assert!(ct.to_str().unwrap().contains(&boundary));
assert_eq!(len.unwrap() as usize, bytes.len());
let s = std::str::from_utf8(&bytes).unwrap();
assert!(s.contains("name=\"name\""));
assert!(s.contains("name=\"language\""));
assert!(s.contains("\r\nglen\r\n"));
assert!(s.contains("\r\nrust\r\n"));
assert!(s.ends_with("--\r\n"));
}
// A bytes part with file name and MIME renders both headers.
#[tokio::test]
async fn test_form_bytes_with_filename_and_mime() {
let part = Part::bytes(b"\x00\x01\x02".as_slice())
.with_file_name("a.bin")
.with_mime(mime::APPLICATION_OCTET_STREAM);
let form = Form::new().part("avatar", part);
let (_, len, bytes) = collect(form).await;
assert!(len.is_some());
// Only the headers are valid UTF-8 text here; cut at the first body byte (0x00).
let s = std::str::from_utf8(&bytes[..bytes.iter().position(|&b| b == 0).unwrap()]).unwrap();
assert!(s.contains("filename=\"a.bin\""));
assert!(s.contains("Content-Type: application/octet-stream"));
}
// A streamed part without an explicit size makes the form length unknown.
#[tokio::test]
async fn test_form_unknown_length_when_streaming() {
let part = Part::stream(stream::iter([
Ok::<Bytes, BoxError>(Bytes::from_static(b"hello ")),
Ok::<Bytes, BoxError>(Bytes::from_static(b"world")),
]));
let form = Form::new().part("payload", part);
assert!(form.content_length().is_none());
let (_, _len, bytes) = collect(form).await;
let s = std::str::from_utf8(&bytes).unwrap();
assert!(s.contains("hello world"));
}
// A streamed part with an explicit size yields an exact form length.
#[tokio::test]
async fn test_form_known_length_when_streaming_with_content_size() {
let part = Part::stream(stream::iter([Ok::<Bytes, BoxError>(Bytes::from_static(
b"abcdef",
))]))
.with_content_size(6);
let form = Form::new().part("payload", part);
let len = form.content_length().expect("length known");
let (_, _, bytes) = collect(form).await;
assert_eq!(len as usize, bytes.len());
}
// Quotes in a field name are backslash-escaped.
#[tokio::test]
async fn test_form_quoting_escapes_quotes() {
let form = Form::new().text("we\"ird", "v");
let (_, _, bytes) = collect(form).await;
let s = std::str::from_utf8(&bytes).unwrap();
assert!(s.contains("name=\"we\\\"ird\""));
}
// MIME parameters survive rendering and are counted in the length.
#[tokio::test]
async fn test_form_preserves_mime_parameters() {
let part = Part::bytes(b"hi".as_slice())
.try_with_mime_str("text/plain; charset=utf-8")
.unwrap();
let form = Form::new().part("note", part);
let len = form.content_length().expect("length known");
let (_, _, bytes) = collect(form).await;
assert_eq!(len as usize, bytes.len());
let s = std::str::from_utf8(&bytes).unwrap();
assert!(
s.contains("Content-Type: text/plain; charset=utf-8"),
"rendered body: {s}"
);
}
// Plain `name=value` parses as a text source without modifiers.
#[test]
fn test_field_spec_text() {
let s = FieldSpec::parse("name=glen").unwrap();
assert_eq!(s.name, "name");
assert!(matches!(s.source, FieldSpecSource::Text("glen")));
assert!(s.content_type.is_none());
assert!(s.filename.is_none());
}
// `@path` with `type` and `filename` modifiers.
#[test]
fn test_field_spec_file_with_modifiers() {
let s = FieldSpec::parse("avatar=@./photo.png;type=image/png;filename=me.png").unwrap();
assert_eq!(s.name, "avatar");
assert!(matches!(s.source, FieldSpecSource::File("./photo.png")));
assert_eq!(s.content_type, Some("image/png"));
assert_eq!(s.filename, Some("me.png"));
}
// `<path` parses as a file-text source.
#[test]
fn test_field_spec_file_text() {
let s = FieldSpec::parse("greeting=<hello.txt").unwrap();
assert_eq!(s.name, "greeting");
assert!(matches!(s.source, FieldSpecSource::FileText("hello.txt")));
}
// `@-` keeps `-` as the path; it is resolved to stdin later.
#[test]
fn test_field_spec_stdin() {
let s = FieldSpec::parse("blob=@-").unwrap();
assert!(matches!(s.source, FieldSpecSource::File("-")));
}
// Each malformed spec maps to its dedicated error variant.
#[test]
fn test_field_spec_errors() {
assert!(matches!(
FieldSpec::parse("noequal"),
Err(FieldSpecError::MissingSeparator)
));
assert!(matches!(
FieldSpec::parse("=value"),
Err(FieldSpecError::EmptyName)
));
assert!(matches!(
FieldSpec::parse("name=v;invalid"),
Err(FieldSpecError::InvalidModifier(_))
));
assert!(matches!(
FieldSpec::parse("name=v;weird=val"),
Err(FieldSpecError::InvalidModifier(_))
));
}
// Chained text specs end up as two rendered fields.
#[tokio::test]
async fn test_form_with_field_spec_text() {
let form = Form::new()
.with_field_spec("name=glen")
.await
.unwrap()
.with_field_spec("lang=rust")
.await
.unwrap();
let (_, _, bytes) = collect(form).await;
let s = std::str::from_utf8(&bytes).unwrap();
assert!(s.contains("name=\"name\""));
assert!(s.contains("\r\nglen\r\n"));
assert!(s.contains("name=\"lang\""));
assert!(s.contains("\r\nrust\r\n"));
}
// A file spec streams the file and takes the file name from the path.
#[tokio::test]
async fn test_form_with_field_spec_file() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("hello.txt");
tokio::fs::write(&path, b"hi from disk").await.unwrap();
let spec = format!("note=@{};type=text/plain", path.display());
let form = Form::new().with_field_spec(&spec).await.unwrap();
let (_, _, bytes) = collect(form).await;
let s = std::str::from_utf8(&bytes).unwrap();
assert!(s.contains("name=\"note\""));
assert!(s.contains("filename=\"hello.txt\""));
assert!(s.contains("Content-Type: text/plain"));
assert!(s.contains("hi from disk"));
}
// A `<path` spec may contain `..` components.
#[tokio::test]
async fn test_field_spec_file_text_allows_parent_dir_paths() {
let dir = tempfile::tempdir().unwrap();
let child = dir.path().join("child");
tokio::fs::create_dir(&child).await.unwrap();
let payload = dir.path().join("payload.txt");
tokio::fs::write(&payload, b"hello parent").await.unwrap();
let spec = format!("greeting=<{}", child.join("../payload.txt").display());
let form = Form::new().with_field_spec(&spec).await.unwrap();
let (_, _, bytes) = collect(form).await;
let s = std::str::from_utf8(&bytes).unwrap();
assert!(s.contains("\r\nhello parent\r\n"));
}
}