use std::sync::Arc;
use async_trait::async_trait;
use llmsdk_provider::ProviderError;
use llmsdk_provider::error::Result;
use llmsdk_provider::shared::FileBytes;
use llmsdk_provider::{FilesModel, UploadFileData, UploadFileOptions, UploadFileResult};
use llmsdk_provider_utils::http::{RawRequest, post_raw};
use llmsdk_provider_utils::multipart::Multipart;
use serde_json::{Map as JsonMap, Value as JsonValue};
use crate::PROVIDER_ID;
use crate::config::Inner;
use crate::files::options::parse_xai_files_options;
use crate::files::wire::WireUploadResponse;
/// Filename sent in the multipart `file` part when the caller supplies none.
const DEFAULT_FILENAME: &str = "blob";

/// xAI implementation of [`FilesModel`], uploading to `{base_url}/files`.
///
/// Clones share the same provider configuration through an [`Arc`].
#[derive(Clone, Debug)]
pub struct XaiFiles {
    /// Shared provider configuration (base URL, default headers, HTTP client).
    inner: Arc<Inner>,
    /// Provider id reported by [`FilesModel::provider`], fixed at construction.
    provider: String,
}
impl XaiFiles {
    /// Creates a files client over the shared provider configuration.
    ///
    /// The provider id is computed once here as `"<PROVIDER_ID>.files"`.
    pub(crate) fn new(inner: Arc<Inner>) -> Self {
        let provider = format!("{PROVIDER_ID}.files");
        Self { inner, provider }
    }

    /// Absolute URL of the upload endpoint: the configured base URL followed by `/files`.
    fn endpoint(&self) -> String {
        let base = &self.inner.base_url;
        format!("{base}/files")
    }
}
#[async_trait]
impl FilesModel for XaiFiles {
    /// Returns the provider id chosen in [`XaiFiles::new`].
    fn provider(&self) -> &str {
        self.provider.as_str()
    }

    /// Uploads a file to the xAI files endpoint as `multipart/form-data`.
    ///
    /// The payload goes into a `file` part. That part is named after
    /// `options.filename` (or [`DEFAULT_FILENAME`] when absent) and tagged with
    /// `options.media_type`. A `team_id` text part is added when the parsed
    /// xAI provider options carry one.
    ///
    /// The server-assigned file id is returned in `provider_reference` under
    /// [`PROVIDER_ID`]. Whichever of `filename`, `bytes` and `createdAt` the
    /// server echoed back are copied into `provider_metadata`. The returned
    /// filename prefers the server's value and falls back to the caller's.
    ///
    /// # Errors
    ///
    /// Propagates failures from provider-option parsing, from decoding the
    /// input data (e.g. malformed base64), and from the HTTP request itself.
    async fn upload_file(&self, options: UploadFileOptions) -> Result<UploadFileResult> {
        let xai_opts = parse_xai_files_options(options.provider_options.as_ref())?;
        let payload = upload_data_to_bytes(&options.data)?;

        // Multipart body: the file part first, then optional text fields.
        let form_filename: &str = options.filename.as_deref().unwrap_or(DEFAULT_FILENAME);
        let mut form = Multipart::new();
        form.file("file", form_filename, Some(&options.media_type), &payload);
        if let Some(team_id) = xai_opts.team_id.as_deref() {
            form.text("team_id", team_id);
        }
        let (boundary, body) = form.finish();

        let mut request = RawRequest::new(
            self.endpoint(),
            body,
            format!("multipart/form-data; boundary={boundary}"),
        );
        request.headers = self.inner.headers.clone();
        let response = post_raw::<WireUploadResponse>(&self.inner.http, request)
            .await?
            .value;

        let provider_reference =
            std::collections::HashMap::from([(PROVIDER_ID.to_owned(), response.id.clone())]);

        // Only fields the server actually returned are recorded as metadata.
        let mut metadata = JsonMap::new();
        if let Some(name) = &response.filename {
            metadata.insert("filename".to_owned(), JsonValue::from(name.clone()));
        }
        if let Some(size) = response.bytes {
            metadata.insert("bytes".to_owned(), JsonValue::from(size));
        }
        if let Some(created) = response.created_at {
            metadata.insert("createdAt".to_owned(), JsonValue::from(created));
        }

        Ok(UploadFileResult {
            provider_reference,
            media_type: Some(options.media_type.clone()),
            filename: response
                .filename
                .clone()
                .or_else(|| options.filename.clone()),
            provider_metadata: Some(std::collections::HashMap::from([(
                PROVIDER_ID.to_owned(),
                metadata,
            )])),
            warnings: Vec::new(),
        })
    }
}
/// Converts upload input into the raw bytes sent in the multipart `file` part.
///
/// Raw bytes are copied as-is, text is encoded as UTF-8, and base64 strings
/// are decoded with [`base64_decode`].
///
/// # Errors
///
/// Returns a type-validation error for path `data.data` when base64 input
/// cannot be decoded. The error carries the offending string and the decoder's
/// reason.
fn upload_data_to_bytes(data: &UploadFileData) -> Result<Vec<u8>> {
    let encoded = match data {
        UploadFileData::Data {
            data: FileBytes::Base64(encoded),
        } => encoded,
        UploadFileData::Data {
            data: FileBytes::Bytes(raw),
        } => return Ok(raw.clone()),
        UploadFileData::Text { text } => return Ok(text.as_bytes().to_vec()),
    };
    base64_decode(encoded).map_err(|err| {
        ProviderError::type_validation(
            "data.data",
            JsonValue::String(encoded.clone()),
            format!("invalid base64: {err}"),
        )
    })
}
/// Decodes standard base64 (RFC 4648 §4 alphabet) with mandatory `=` padding.
///
/// Input rules:
/// - The input must consist of whole 4-character quanta.
/// - `=` may only appear as the last one or two characters of the *final*
///   quantum.
/// - Whitespace, line breaks and the URL-safe alphabet (`-`, `_`) are rejected.
///
/// An empty input decodes to an empty vector. Non-zero trailing bits in a
/// padded quantum are tolerated: RFC 4648 §3.5 permits rejecting them but
/// does not require it.
///
/// Decoding proceeds quantum by quantum and stops at the first error.
///
/// # Errors
///
/// - [`Base64Error::Length`] if the input length is not a multiple of 4.
/// - [`Base64Error::Byte`] for a byte outside the alphabet.
/// - [`Base64Error::Padding`] if `=` appears anywhere other than the tail of
///   the final quantum.
fn base64_decode(input: &str) -> std::result::Result<Vec<u8>, Base64Error> {
    let bytes = input.as_bytes();
    if !bytes.len().is_multiple_of(4) {
        return Err(Base64Error::Length);
    }
    let quanta = bytes.len() / 4;
    let mut out = Vec::with_capacity(quanta * 3);
    for (index, chunk) in bytes.chunks_exact(4).enumerate() {
        let (b0, p0) = decode_byte(chunk[0])?;
        let (b1, p1) = decode_byte(chunk[1])?;
        let (b2, p2) = decode_byte(chunk[2])?;
        let (b3, p3) = decode_byte(chunk[3])?;
        let padded = p2 || p3;
        let is_last = index + 1 == quanta;
        // Rules enforced here:
        // - Characters 1-2 of a quantum always carry data.
        // - A padded third character must be followed by a padded fourth
        //   ("xx=y" is malformed).
        // - Only the final quantum may be padded. Otherwise "YQ==YWJj" would
        //   silently decode as two concatenated payloads.
        if p0 || p1 || (p2 && !p3) || (padded && !is_last) {
            return Err(Base64Error::Padding);
        }
        let n =
            (u32::from(b0) << 18) | (u32::from(b1) << 12) | (u32::from(b2) << 6) | u32::from(b3);
        out.push(((n >> 16) & 0xFF) as u8);
        if !p2 {
            out.push(((n >> 8) & 0xFF) as u8);
        }
        if !p3 {
            out.push((n & 0xFF) as u8);
        }
    }
    Ok(out)
}

/// Maps one base64 character to its 6-bit value.
///
/// Returns `(value, is_padding)`. `=` maps to `(0, true)`, so the caller can
/// decode a padded quantum uniformly and decide separately whether padding is
/// allowed at that position.
///
/// # Errors
///
/// [`Base64Error::Byte`] for any byte outside `A-Z`, `a-z`, `0-9`, `+`, `/`, `=`.
fn decode_byte(b: u8) -> std::result::Result<(u8, bool), Base64Error> {
    Ok(match b {
        b'A'..=b'Z' => (b - b'A', false),
        b'a'..=b'z' => (b - b'a' + 26, false),
        b'0'..=b'9' => (b - b'0' + 52, false),
        b'+' => (62, false),
        b'/' => (63, false),
        b'=' => (0, true),
        _ => return Err(Base64Error::Byte(b)),
    })
}

/// Reasons [`base64_decode`] can reject its input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Base64Error {
    /// Input length is not a multiple of 4.
    Length,
    /// `=` appeared somewhere other than the tail of the final quantum.
    Padding,
    /// A byte outside the standard base64 alphabet (the offending byte).
    Byte(u8),
}

impl std::fmt::Display for Base64Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Length => f.write_str("input length is not a multiple of 4"),
            Self::Padding => f.write_str("misplaced padding"),
            Self::Byte(b) => write!(f, "non-alphabet byte 0x{b:02x}"),
        }
    }
}

impl std::error::Error for Base64Error {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a base64 string through [`upload_data_to_bytes`].
    fn decode_b64(encoded: &str) -> Result<Vec<u8>> {
        upload_data_to_bytes(&UploadFileData::Data {
            data: FileBytes::Base64(encoded.into()),
        })
    }

    #[test]
    fn data_bytes_passes_through() {
        let r = upload_data_to_bytes(&UploadFileData::Data {
            data: FileBytes::Bytes(vec![1, 2, 3]),
        })
        .expect("decodes");
        assert_eq!(r, vec![1, 2, 3]);
    }

    #[test]
    fn data_base64_decodes() {
        assert_eq!(decode_b64("dGVzdA==").expect("decodes"), b"test");
    }

    #[test]
    fn data_base64_single_padding() {
        assert_eq!(decode_b64("YWI=").expect("decodes"), b"ab");
    }

    #[test]
    fn data_base64_unpadded_quantum() {
        assert_eq!(decode_b64("YWJj").expect("decodes"), b"abc");
    }

    #[test]
    fn data_base64_empty_is_empty() {
        assert!(decode_b64("").expect("decodes").is_empty());
    }

    #[test]
    fn data_base64_rejects_invalid() {
        let err = decode_b64("not_padded").unwrap_err();
        assert!(format!("{err}").contains("base64"));
    }

    #[test]
    fn data_base64_rejects_non_alphabet_byte() {
        // `_` belongs to the URL-safe alphabet, not the standard one.
        let err = decode_b64("YW_j").unwrap_err();
        assert!(format!("{err}").contains("0x5f"));
    }

    #[test]
    fn data_base64_rejects_leading_padding() {
        let err = decode_b64("Y===").unwrap_err();
        assert!(format!("{err}").contains("padding"));
    }

    #[test]
    fn text_encodes_utf8() {
        let r = upload_data_to_bytes(&UploadFileData::Text {
            text: "héllo".into(),
        })
        .expect("decodes");
        assert_eq!(r, "héllo".as_bytes());
    }

    #[test]
    fn text_empty_is_empty() {
        let r = upload_data_to_bytes(&UploadFileData::Text { text: "".into() })
            .expect("decodes");
        assert!(r.is_empty());
    }
}