use base64::Engine;
use bytes::{BufMut, Bytes, BytesMut};
use futures_util::{Stream, StreamExt};
use reqwest::{Client, Response, StatusCode};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::pin::Pin;
use std::time::Duration;
use thiserror::Error;
use url::Url;
/// Structured error payload that the client tries to deserialize from the body
/// of a non-success function response.
///
/// Every field is optional, so any JSON object is accepted and absent keys
/// simply become `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionErrorDetails {
/// Error message from the payload; when present it is used as the message of
/// the resulting `FunctionsError::FunctionError`.
pub message: Option<String>,
/// Value of the payload's `status` key, if any.
pub status: Option<u16>,
/// Value of the payload's `code` key, if any.
pub code: Option<String>,
/// Arbitrary JSON found under the payload's `details` key, if any.
pub details: Option<Value>,
}
/// Errors produced while building, sending, or decoding a function invocation.
#[derive(Debug, Error)]
pub enum FunctionsError {
/// The underlying `reqwest` call failed (converted automatically via `From`).
#[error("Request error: {0}")]
RequestError(#[from] reqwest::Error),
/// The function URL could not be built from the configured base URL.
#[error("URL parse error: {0}")]
UrlError(#[from] url::ParseError),
/// JSON (de)serialization failed.
#[error("JSON error: {0}")]
JsonError(#[from] serde_json::Error),
/// The function answered with a non-success HTTP status.
#[error("Function error (status: {status}): {message}")]
FunctionError {
/// Message taken from the error payload, or a fallback text.
message: String,
/// HTTP status associated with the error.
status: StatusCode,
/// Parsed error payload, when the body could be deserialized as one.
details: Option<FunctionErrorDetails>,
},
/// The request failed and `reqwest` classified the failure as a timeout.
#[error("Timeout error: Function execution exceeded timeout limit")]
TimeoutError,
/// The response (or a stream processing step) could not be handled as
/// requested; the string explains why.
#[error("Invalid response: {0}")]
InvalidResponse(String),
}
impl FunctionsError {
    /// Creates a [`FunctionsError::FunctionError`] with the given message,
    /// a `500 Internal Server Error` status and no details.
    pub fn new(message: String) -> Self {
        let status = StatusCode::INTERNAL_SERVER_ERROR;
        Self::FunctionError {
            message,
            status,
            details: None,
        }
    }

    /// Creates a [`FunctionsError::FunctionError`] from a response's status,
    /// using a generic message that mentions that status.
    pub fn from_response(response: &Response) -> Self {
        let status = response.status();
        Self::FunctionError {
            message: format!("Function returned error status: {}", status),
            status,
            details: None,
        }
    }

    /// Creates a [`FunctionsError::FunctionError`] from a response's status and
    /// a parsed error payload.
    ///
    /// The payload's `message` is used when present; otherwise the same generic
    /// status message as [`FunctionsError::from_response`] is used.
    pub fn with_details(response: &Response, details: FunctionErrorDetails) -> Self {
        let status = response.status();
        let message = match &details.message {
            Some(msg) => msg.clone(),
            None => format!("Function returned error status: {}", status),
        };
        Self::FunctionError {
            message,
            status,
            details: Some(details),
        }
    }
}
/// Convenience alias: a `std::result::Result` whose error type is [`FunctionsError`].
pub type Result<T> = std::result::Result<T, FunctionsError>;
/// Per-invocation settings for a function call.
#[derive(Clone, Debug)]
pub struct FunctionOptions {
/// Extra request headers, added after the client's own headers.
pub headers: Option<HashMap<String, String>>,
/// Request timeout in whole seconds; no per-request timeout is set when `None`.
pub timeout_seconds: Option<u64>,
/// How `FunctionsClient::invoke` should interpret the response body.
pub response_type: ResponseType,
/// Value for the request's `Content-Type` header, if it should be set explicitly.
pub content_type: Option<String>,
}
impl Default for FunctionOptions {
    /// Default options: JSON response handling, with no extra headers, no
    /// timeout and no explicit content type.
    fn default() -> Self {
        Self {
            response_type: ResponseType::Json,
            content_type: None,
            timeout_seconds: None,
            headers: None,
        }
    }
}
/// Selects how `FunctionsClient::invoke` turns the response body into the
/// requested type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ResponseType {
/// Decode the body as JSON.
Json,
/// Read the body as text, then deserialize that text with `serde_json`.
Text,
/// Read the raw bytes, base64-encode them (standard alphabet) and
/// deserialize the result as a JSON string.
Binary,
/// Streaming response; not handled by `invoke` — use `invoke_stream`.
Stream,
}
/// A pinned, boxed, `Send` stream of byte chunks, each of which may fail with
/// a [`FunctionsError`].
pub type ByteStream = Pin<Box<dyn Stream<Item = Result<Bytes>> + Send>>;
/// A successful function response: the decoded payload plus HTTP metadata.
#[derive(Debug, Clone)]
pub struct FunctionResponse<T> {
/// The decoded response body.
pub data: T,
/// HTTP status of the response.
pub status: StatusCode,
/// Response headers as name/value strings.
pub headers: HashMap<String, String>,
}
/// Client that invokes functions under `<base_url>/functions/v1/<name>`.
pub struct FunctionsClient {
// Base URL the function path segments are appended to.
base_url: String,
// Key sent both as the `apikey` header and as the `Authorization` bearer token.
api_key: String,
// Underlying HTTP client used for every request.
http_client: Client,
}
/// A handle bound to one function name and one response type `T`, borrowed
/// from a [`FunctionsClient`].
pub struct FunctionRequest<'a, T> {
// Client used to perform the invocation.
client: &'a FunctionsClient,
// Name of the function this request targets.
function_name: String,
// Records the response type without storing a value of it.
_response_type: std::marker::PhantomData<T>,
}
impl<'a, T: DeserializeOwned> FunctionRequest<'a, T> {
    /// Invokes the bound function through the client and returns only the
    /// decoded payload, discarding the status and headers.
    ///
    /// # Errors
    ///
    /// Propagates any [`FunctionsError`] returned by the client's `invoke`.
    pub async fn execute<B: Serialize>(
        &self,
        body: Option<B>,
        options: Option<FunctionOptions>,
    ) -> Result<T> {
        self.client
            .invoke::<T, B>(&self.function_name, body, options)
            .await
            .map(|response| response.data)
    }
}
impl FunctionsClient {
pub fn new(supabase_url: &str, supabase_key: &str, http_client: Client) -> Self {
Self {
base_url: supabase_url.to_string(),
api_key: supabase_key.to_string(),
http_client,
}
}
pub async fn invoke<T: DeserializeOwned, B: Serialize>(
&self,
function_name: &str,
body: Option<B>,
options: Option<FunctionOptions>,
) -> Result<FunctionResponse<T>> {
let opts = options.unwrap_or_default();
let mut url = Url::parse(&self.base_url)?;
url.path_segments_mut()
.map_err(|_| FunctionsError::UrlError(url::ParseError::EmptyHost))?
.push("functions")
.push("v1")
.push(function_name);
let mut request_builder = self
.http_client
.post(url)
.header("apikey", &self.api_key)
.header("Authorization", format!("Bearer {}", &self.api_key));
if let Some(timeout) = opts.timeout_seconds {
request_builder = request_builder.timeout(Duration::from_secs(timeout));
}
if let Some(content_type) = opts.content_type {
request_builder = request_builder.header("Content-Type", content_type);
}
if let Some(headers) = opts.headers {
for (key, value) in headers {
request_builder = request_builder.header(key, value);
}
}
if let Some(body_data) = body {
request_builder = request_builder.json(&body_data);
}
let response = request_builder.send().await.map_err(|e| {
if e.is_timeout() {
FunctionsError::TimeoutError
} else {
FunctionsError::from(e)
}
})?;
let status = response.status();
if !status.is_success() {
let status_copy = status;
let error_body = response
.text()
.await
.unwrap_or_else(|_| "Failed to read error response".to_string());
if let Ok(error_details) = serde_json::from_str::<FunctionErrorDetails>(&error_body) {
return Err(FunctionsError::FunctionError {
message: error_details.message.as_ref().map_or_else(
|| format!("Function returned error status: {}", status_copy),
|msg| msg.clone(),
),
status: status_copy,
details: Some(error_details),
});
} else {
return Err(FunctionsError::FunctionError {
message: error_body,
status: status_copy,
details: None,
});
}
}
let headers = response
.headers()
.iter()
.map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
.collect::<HashMap<String, String>>();
match opts.response_type {
ResponseType::Json => {
let data = response.json::<T>().await.map_err(|e| {
FunctionsError::JsonError(serde_json::from_str::<T>("{}").err().unwrap_or_else(
|| {
serde_json::Error::io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
e.to_string(),
))
},
))
})?;
Ok(FunctionResponse {
data,
status,
headers,
})
}
ResponseType::Text => {
let text = response.text().await?;
let data: T = serde_json::from_str(&text).unwrap_or_else(|_| {
panic!("Failed to deserialize text response as requested type")
});
Ok(FunctionResponse {
data,
status,
headers,
})
}
ResponseType::Binary => {
let bytes = response.bytes().await?;
let binary_str = base64::engine::general_purpose::STANDARD.encode(&bytes);
let data: T =
serde_json::from_str(&format!("\"{}\"", binary_str)).unwrap_or_else(|_| {
panic!("Failed to deserialize binary response as requested type")
});
Ok(FunctionResponse {
data,
status,
headers,
})
}
ResponseType::Stream => {
Err(FunctionsError::InvalidResponse(
"Stream response type cannot be handled by invoke(). Use invoke_stream() instead.".to_string()
))
}
}
}
pub async fn invoke_json<T: DeserializeOwned, B: Serialize>(
&self,
function_name: &str,
body: Option<B>,
) -> Result<T> {
let options = FunctionOptions {
response_type: ResponseType::Json,
..Default::default()
};
let response = self
.invoke::<T, B>(function_name, body, Some(options))
.await?;
Ok(response.data)
}
pub async fn invoke_text<B: Serialize>(
&self,
function_name: &str,
body: Option<B>,
) -> Result<String> {
let options = FunctionOptions {
response_type: ResponseType::Text,
..Default::default()
};
let mut url = Url::parse(&self.base_url)?;
url.path_segments_mut()
.map_err(|_| FunctionsError::UrlError(url::ParseError::EmptyHost))?
.push("functions")
.push("v1")
.push(function_name);
let mut request_builder = self
.http_client
.post(url)
.header("apikey", &self.api_key)
.header("Authorization", format!("Bearer {}", &self.api_key));
if let Some(timeout) = options.timeout_seconds {
request_builder = request_builder.timeout(Duration::from_secs(timeout));
}
if let Some(content_type) = options.content_type {
request_builder = request_builder.header("Content-Type", content_type);
} else {
request_builder = request_builder.header("Content-Type", "application/json");
}
request_builder = request_builder.header("Accept", "text/plain, */*;q=0.9");
if let Some(headers) = options.headers {
for (key, value) in headers {
request_builder = request_builder.header(key, value);
}
}
if let Some(body_data) = body {
request_builder = request_builder.json(&body_data);
}
let response = request_builder.send().await.map_err(|e| {
if e.is_timeout() {
FunctionsError::TimeoutError
} else {
FunctionsError::from(e)
}
})?;
let status = response.status();
if !status.is_success() {
let error_body = response
.text()
.await
.unwrap_or_else(|_| "Failed to read error response".to_string());
if let Ok(error_details) = serde_json::from_str::<FunctionErrorDetails>(&error_body) {
return Err(FunctionsError::FunctionError {
message: error_details.message.as_ref().map_or_else(
|| format!("Function returned error status: {}", status),
|msg| msg.clone(),
),
status,
details: Some(error_details),
});
} else {
return Err(FunctionsError::FunctionError {
message: error_body,
status,
details: None,
});
}
}
response.text().await.map_err(FunctionsError::from)
}
pub async fn invoke_binary<B: Serialize>(
&self,
function_name: &str,
body: Option<B>,
options: Option<FunctionOptions>,
) -> Result<Bytes> {
let options = options.unwrap_or_else(|| FunctionOptions {
response_type: ResponseType::Binary,
..Default::default()
});
let mut url = Url::parse(&self.base_url)?;
url.path_segments_mut()
.map_err(|_| FunctionsError::UrlError(url::ParseError::EmptyHost))?
.push("functions")
.push("v1")
.push(function_name);
let mut request_builder = self
.http_client
.post(url)
.header("apikey", &self.api_key)
.header("Authorization", format!("Bearer {}", &self.api_key));
if let Some(timeout) = options.timeout_seconds {
request_builder = request_builder.timeout(Duration::from_secs(timeout));
}
if let Some(content_type) = options.content_type {
request_builder = request_builder.header("Content-Type", content_type);
} else {
request_builder = request_builder.header("Content-Type", "application/json");
}
request_builder = request_builder.header("Accept", "application/octet-stream");
if let Some(headers) = options.headers {
for (key, value) in headers {
request_builder = request_builder.header(key, value);
}
}
if let Some(body_data) = body {
request_builder = request_builder.json(&body_data);
}
let response = request_builder.send().await.map_err(|e| {
if e.is_timeout() {
FunctionsError::TimeoutError
} else {
FunctionsError::from(e)
}
})?;
let status = response.status();
if !status.is_success() {
let error_body = response
.text()
.await
.unwrap_or_else(|_| "Failed to read error response".to_string());
if let Ok(error_details) = serde_json::from_str::<FunctionErrorDetails>(&error_body) {
return Err(FunctionsError::FunctionError {
message: error_details.message.as_ref().map_or_else(
|| format!("Function returned error status: {}", status),
|msg| msg.clone(),
),
status,
details: Some(error_details),
});
} else {
return Err(FunctionsError::FunctionError {
message: error_body,
status,
details: None,
});
}
}
response.bytes().await.map_err(FunctionsError::from)
}
pub async fn invoke_binary_stream<B: Serialize>(
&self,
function_name: &str,
body: Option<B>,
options: Option<FunctionOptions>,
) -> Result<ByteStream> {
let opts = options.unwrap_or_else(|| FunctionOptions {
response_type: ResponseType::Stream,
content_type: Some("application/octet-stream".to_string()),
..Default::default()
});
let mut custom_opts = opts;
let mut headers = custom_opts.headers.unwrap_or_default();
headers.insert("Accept".to_string(), "application/octet-stream".to_string());
custom_opts.headers = Some(headers);
self.invoke_stream(function_name, body, Some(custom_opts))
.await
}
pub fn process_binary_chunks<F>(
&self,
stream: ByteStream,
chunk_size: usize,
mut processor: F,
) -> Pin<Box<dyn Stream<Item = Result<Bytes>> + Send + '_>>
where
F: FnMut(&[u8]) -> std::result::Result<Bytes, String> + Send + 'static,
{
Box::pin(async_stream::stream! {
let mut buffer = BytesMut::new();
tokio::pin!(stream);
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
buffer.extend_from_slice(&chunk);
while buffer.len() >= chunk_size {
let chunk_to_process = buffer.split_to(chunk_size);
match processor(&chunk_to_process) {
Ok(processed) => yield Ok(processed),
Err(err) => {
yield Err(FunctionsError::InvalidResponse(err));
return;
}
}
}
},
Err(e) => {
yield Err(e);
return;
}
}
}
if !buffer.is_empty() {
match processor(&buffer) {
Ok(processed) => yield Ok(processed),
Err(err) => yield Err(FunctionsError::InvalidResponse(err)),
}
}
})
}
pub async fn invoke_stream<B: Serialize>(
&self,
function_name: &str,
body: Option<B>,
options: Option<FunctionOptions>,
) -> Result<ByteStream> {
let opts = options.unwrap_or_else(|| FunctionOptions {
response_type: ResponseType::Stream,
..Default::default()
});
let mut url = Url::parse(&self.base_url)?;
url.path_segments_mut()
.map_err(|_| FunctionsError::UrlError(url::ParseError::EmptyHost))?
.push("functions")
.push("v1")
.push(function_name);
let mut request_builder = self
.http_client
.post(url)
.header("apikey", &self.api_key)
.header("Authorization", format!("Bearer {}", &self.api_key));
if let Some(timeout) = opts.timeout_seconds {
request_builder = request_builder.timeout(Duration::from_secs(timeout));
}
if let Some(content_type) = opts.content_type {
request_builder = request_builder.header("Content-Type", content_type);
} else {
request_builder = request_builder.header("Content-Type", "application/json");
}
if let Some(headers) = opts.headers {
for (key, value) in headers {
request_builder = request_builder.header(key, value);
}
}
if let Some(body_data) = body {
request_builder = request_builder.json(&body_data);
}
let response = request_builder.send().await.map_err(|e| {
if e.is_timeout() {
FunctionsError::TimeoutError
} else {
FunctionsError::from(e)
}
})?;
let status = response.status();
if !status.is_success() {
let status_copy = status;
let error_body = response
.text()
.await
.unwrap_or_else(|_| "Failed to read error response".to_string());
if let Ok(error_details) = serde_json::from_str::<FunctionErrorDetails>(&error_body) {
return Err(FunctionsError::FunctionError {
message: error_details.message.as_ref().map_or_else(
|| format!("Function returned error status: {}", status_copy),
|msg| msg.clone(),
),
status: status_copy,
details: Some(error_details),
});
} else {
return Err(FunctionsError::FunctionError {
message: error_body,
status: status_copy,
details: None,
});
}
}
Ok(Box::pin(
response
.bytes_stream()
.map(|result| result.map_err(FunctionsError::from)),
))
}
pub async fn invoke_json_stream<B: Serialize>(
&self,
function_name: &str,
body: Option<B>,
options: Option<FunctionOptions>,
) -> Result<Pin<Box<dyn Stream<Item = Result<Value>> + Send + '_>>> {
let byte_stream = self.invoke_stream(function_name, body, options).await?;
let json_stream = self.byte_stream_to_json(byte_stream);
Ok(json_stream)
}
fn byte_stream_to_json(
&self,
stream: ByteStream,
) -> Pin<Box<dyn Stream<Item = Result<Value>> + Send + '_>> {
Box::pin(async_stream::stream! {
let mut line_stream = self.stream_to_lines(stream);
while let Some(line_result) = line_stream.next().await {
match line_result {
Ok(line) => {
if line.trim().is_empty() {
continue;
}
match serde_json::from_str::<Value>(&line) {
Ok(json_value) => {
yield Ok(json_value);
},
Err(err) => {
yield Err(FunctionsError::JsonError(err));
}
}
},
Err(err) => {
yield Err(err);
break;
}
}
}
})
}
pub fn stream_to_lines(
&self,
stream: ByteStream,
) -> Pin<Box<dyn Stream<Item = Result<String>> + Send + '_>> {
Box::pin(async_stream::stream! {
let mut buf = BytesMut::new();
tokio::pin!(stream);
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
buf.extend_from_slice(&chunk);
while let Some(i) = buf.iter().position(|&b| b == b'\n') {
let line = if i > 0 && buf[i - 1] == b'\r' {
let line = String::from_utf8_lossy(&buf[..i - 1]).to_string();
unsafe { buf.advance_mut(i + 1); }
line
} else {
let line = String::from_utf8_lossy(&buf[..i]).to_string();
unsafe { buf.advance_mut(i + 1); }
line
};
yield Ok(line);
}
},
Err(e) => {
yield Err(e);
break;
}
}
}
if !buf.is_empty() {
let line = String::from_utf8_lossy(&buf).to_string();
yield Ok(line);
}
})
}
pub fn create_request<T: DeserializeOwned>(
&self,
function_name: &str,
) -> FunctionRequest<'_, T> {
FunctionRequest {
client: self,
function_name: function_name.to_string(),
_response_type: std::marker::PhantomData,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{body_json, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Minimal JSON payload used as a typed response in the tests.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestPayload {
        message: String,
    }

    /// `invoke` forwards custom headers and returns the decoded body together
    /// with the response status and headers.
    #[tokio::test]
    async fn test_invoke() {
        let server = MockServer::start().await;
        let mock_uri = server.uri();
        let api_key = "test-key";
        let function_name = "echo";
        let request_body = json!({ "name": "Rust" });
        let expected_response = TestPayload {
            message: "Hello Rust".to_string(),
        };
        Mock::given(method("POST"))
            .and(path(format!("/functions/v1/{}", function_name)))
            .and(header("apikey", api_key))
            .and(header(
                "Authorization",
                format!("Bearer {}", api_key).as_str(),
            ))
            .and(header("x-region", "eu-west-1"))
            .and(body_json(&request_body))
            .respond_with(ResponseTemplate::new(200).set_body_json(&expected_response))
            .mount(&server)
            .await;
        let client = FunctionsClient::new(&mock_uri, api_key, reqwest::Client::new());
        let options = FunctionOptions {
            headers: Some(HashMap::from([(
                "x-region".to_string(),
                "eu-west-1".to_string(),
            )])),
            ..Default::default()
        };
        let response = client
            .invoke::<TestPayload, Value>(function_name, Some(request_body), Some(options))
            .await
            .expect("invoke should succeed against the mock server");
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.data, expected_response);
        assert!(response.headers.contains_key("content-type"));
        server.verify().await;
    }

    /// `invoke_json` sends the JSON body with the auth headers and decodes the
    /// typed response.
    #[tokio::test]
    async fn test_invoke_json_success() {
        let server = MockServer::start().await;
        let mock_uri = server.uri();
        let api_key = "test-key";
        let function_name = "hello-world";
        let request_body = json!({ "name": "Rust" });
        let expected_response = TestPayload {
            message: "Hello Rust".to_string(),
        };
        Mock::given(method("POST"))
            .and(path(format!("/functions/v1/{}", function_name)))
            .and(header("apikey", api_key))
            .and(header(
                "Authorization",
                format!("Bearer {}", api_key).as_str(),
            ))
            .and(header("Content-Type", "application/json"))
            .and(body_json(&request_body))
            .respond_with(ResponseTemplate::new(200).set_body_json(&expected_response))
            .mount(&server)
            .await;
        let client = FunctionsClient::new(&mock_uri, api_key, reqwest::Client::new());
        let result = client
            .invoke_json::<TestPayload, Value>(function_name, Some(request_body))
            .await;
        assert!(result.is_ok());
        let data = result.unwrap();
        assert_eq!(data, expected_response);
        server.verify().await;
    }

    /// A non-success response with a JSON error payload surfaces as
    /// `FunctionError` carrying the parsed details.
    #[tokio::test]
    async fn test_invoke_json_error_with_details() {
        let server = MockServer::start().await;
        let mock_uri = server.uri();
        let api_key = "test-key";
        let function_name = "error-func";
        let request_body = json!({ "input": "invalid" });
        let error_response_body = json!({
            "message": "Something went wrong!",
            "code": "FUNC_ERROR",
            "details": { "reason": "Internal failure" }
        });
        Mock::given(method("POST"))
            .and(path(format!("/functions/v1/{}", function_name)))
            .and(header("apikey", api_key))
            .and(header(
                "Authorization",
                format!("Bearer {}", api_key).as_str(),
            ))
            .and(body_json(&request_body))
            .respond_with(
                ResponseTemplate::new(500)
                    .set_body_json(&error_response_body)
                    .insert_header("Content-Type", "application/json"),
            )
            .mount(&server)
            .await;
        let client = FunctionsClient::new(&mock_uri, api_key, reqwest::Client::new());
        let result = client
            .invoke_json::<Value, Value>(function_name, Some(request_body))
            .await;
        assert!(result.is_err());
        match result.err().unwrap() {
            FunctionsError::FunctionError {
                message,
                status,
                details,
            } => {
                assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
                assert_eq!(message, "Something went wrong!");
                assert!(details.is_some());
                let details_unwrapped = details.unwrap();
                assert_eq!(
                    details_unwrapped.message,
                    Some("Something went wrong!".to_string())
                );
                assert_eq!(details_unwrapped.code, Some("FUNC_ERROR".to_string()));
                assert!(details_unwrapped.details.is_some());
                assert_eq!(
                    details_unwrapped.details.unwrap(),
                    json!({ "reason": "Internal failure" })
                );
            }
            _ => panic!("Expected FunctionError, got different error type"),
        }
        server.verify().await;
    }

    /// `invoke_text` returns a plain-text body unchanged.
    #[tokio::test]
    async fn test_invoke_text_success() {
        let server = MockServer::start().await;
        let mock_uri = server.uri();
        let api_key = "test-key";
        let function_name = "plain-text-func";
        let request_body = json!({ "format": "text" });
        let expected_response_text = "This is a plain text response.";
        Mock::given(method("POST"))
            .and(path(format!("/functions/v1/{}", function_name)))
            .and(header("apikey", api_key))
            .and(header(
                "Authorization",
                format!("Bearer {}", api_key).as_str(),
            ))
            .and(header("Content-Type", "application/json"))
            .and(body_json(&request_body))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string(expected_response_text)
                    .insert_header("Content-Type", "text/plain"),
            )
            .mount(&server)
            .await;
        let client = FunctionsClient::new(&mock_uri, api_key, reqwest::Client::new());
        let result = client
            .invoke_text::<Value>(function_name, Some(request_body))
            .await;
        assert!(result.is_ok());
        let data = result.unwrap();
        assert_eq!(data, expected_response_text);
        server.verify().await;
    }
}