use crate::{
error::{Error, Result},
models::*,
};
use bytes::Bytes;
use futures::stream::{Stream, StreamExt};
use reqwest::Client;
use serde_json::json;
use std::pin::Pin;
/// Input accepted by `Edgee::send` and `Edgee::stream`.
///
/// Can be built from `String`, `&str`, `InputObject` or `Vec<Message>` through
/// the `From` impls in this file.
#[derive(Debug, Clone)]
pub enum Input {
/// Plain text; `Edgee::parse_input` turns it into a single user message
/// with no tools, tool choice, tags or compression model.
Text(String),
/// A full request object whose messages, tools, tool choice, tags and
/// compression model are forwarded unchanged.
Object(InputObject),
}
impl From<String> for Input {
fn from(s: String) -> Self {
Input::Text(s)
}
}
impl From<&str> for Input {
fn from(s: &str) -> Self {
Input::Text(s.to_string())
}
}
impl From<InputObject> for Input {
fn from(obj: InputObject) -> Self {
Input::Object(obj)
}
}
impl From<Vec<Message>> for Input {
    /// Builds a structured input from a message list by passing it to
    /// `InputObject::new`.
    fn from(messages: Vec<Message>) -> Self {
        let object = InputObject::new(messages);
        Self::Object(object)
    }
}
/// Flattened form of an `Input`, produced by `Edgee::parse_input` and read
/// when the JSON request body is assembled.
struct ParsedInput {
/// Always sent under the `"messages"` key.
messages: Vec<Message>,
/// Sent under `"tools"` only when `Some`.
tools: Option<Vec<Tool>>,
/// Sent under `"tool_choice"` only when `Some`; forwarded as raw JSON.
tool_choice: Option<serde_json::Value>,
/// Sent under `"tags"` only when `Some`.
tags: Option<Vec<String>>,
/// Sent under `"compression_model"` only when `Some`.
compression_model: Option<String>,
}
/// Client that posts chat-completion requests to `{base_url}/v1/chat/completions`.
///
/// Construct it with `Edgee::new`, `Edgee::from_env` or `Edgee::with_api_key`.
#[derive(Debug, Clone)]
pub struct Edgee {
/// Supplies the base URL and API key used on every request.
config: EdgeeConfig,
/// `reqwest` client used to issue the HTTP requests.
client: Client,
}
impl Edgee {
pub fn new(config: EdgeeConfig) -> Self {
Self {
config,
client: Client::new(),
}
}
pub fn from_env() -> Result<Self> {
let config = EdgeeConfig::from_env()?;
Ok(Self::new(config))
}
pub fn with_api_key(api_key: impl Into<String>) -> Self {
Self::new(EdgeeConfig::new(api_key))
}
pub async fn send(
&self,
model: impl Into<String>,
input: impl Into<Input>,
) -> Result<SendResponse> {
let input = input.into();
let parsed = self.parse_input(input);
let mut body = json!({
"model": model.into(),
"messages": parsed.messages,
"stream": false,
});
if let Some(tools) = parsed.tools {
body["tools"] = json!(tools);
}
if let Some(tool_choice) = parsed.tool_choice {
body["tool_choice"] = tool_choice;
}
if let Some(tags) = parsed.tags {
body["tags"] = json!(tags);
}
if let Some(compression_model) = &parsed.compression_model {
body["compression_model"] = json!(compression_model);
}
let response = self
.client
.post(format!("{}/v1/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.config.api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await?;
if !response.status().is_success() {
let status = response.status().as_u16();
let message = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(Error::Api { status, message });
}
let send_response: SendResponse = response.json().await?;
Ok(send_response)
}
pub async fn stream(
&self,
model: impl Into<String>,
input: impl Into<Input>,
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>> {
let input = input.into();
let parsed = self.parse_input(input);
let mut body = json!({
"model": model.into(),
"messages": parsed.messages,
"stream": true,
});
if let Some(tools) = parsed.tools {
body["tools"] = json!(tools);
}
if let Some(tool_choice) = parsed.tool_choice {
body["tool_choice"] = tool_choice;
}
if let Some(tags) = parsed.tags {
body["tags"] = json!(tags);
}
if let Some(compression_model) = &parsed.compression_model {
body["compression_model"] = json!(compression_model);
}
let response = self
.client
.post(format!("{}/v1/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.config.api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await?;
if !response.status().is_success() {
let status = response.status().as_u16();
let message = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(Error::Api { status, message });
}
let stream = response.bytes_stream();
let parsed_stream = Self::parse_sse_stream(stream);
Ok(Box::pin(parsed_stream))
}
fn parse_sse_stream(
stream: impl Stream<Item = reqwest::Result<Bytes>> + Send + 'static,
) -> impl Stream<Item = Result<StreamChunk>> + Send {
let mut buffer = String::new();
stream
.map(move |result| {
let bytes = result.map_err(Error::Http)?;
let text = String::from_utf8_lossy(&bytes);
buffer.push_str(&text);
let mut chunks = Vec::new();
while let Some(pos) = buffer.find("\n\n") {
let chunk = buffer[..pos].to_string();
buffer.drain(..pos + 2);
if chunk.is_empty() {
continue;
}
for line in chunk.lines() {
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
continue;
}
match serde_json::from_str::<StreamChunk>(data) {
Ok(parsed_chunk) => chunks.push(Ok(parsed_chunk)),
Err(e) => {
eprintln!("Failed to parse chunk: {}", e);
}
}
}
}
}
Ok(chunks)
})
.flat_map(|result: Result<Vec<Result<StreamChunk>>>| match result {
Ok(chunks) => futures::stream::iter(chunks).boxed(),
Err(e) => futures::stream::once(async move { Err(e) }).boxed(),
})
}
fn parse_input(&self, input: Input) -> ParsedInput {
match input {
Input::Text(text) => ParsedInput {
messages: vec![Message::user(text)],
tools: None,
tool_choice: None,
tags: None,
compression_model: None,
},
Input::Object(obj) => ParsedInput {
messages: obj.messages,
tools: obj.tools,
tool_choice: obj.tool_choice,
tags: obj.tags,
compression_model: obj.compression_model,
},
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported source type converts into `Input`.
    #[test]
    fn test_input_conversions() {
        let _input: Input = "hello".into();
        let _input: Input = "hello".to_string().into();
        let _input: Input = InputObject::new(vec![Message::user("hello")]).into();
        let _input: Input = vec![Message::user("hello")].into();
    }

    /// `EdgeeConfig::from_env` picks up `EDGEE_API_KEY` and `EDGEE_BASE_URL`.
    #[test]
    fn test_config_from_env() {
        // SAFETY: assumes no other thread reads or writes the process
        // environment while this test runs; no other test in this module
        // touches these variables. TODO confirm for the wider test suite.
        unsafe {
            std::env::set_var("EDGEE_API_KEY", "test-key");
            std::env::set_var("EDGEE_BASE_URL", "https://test.example.com");
        }

        // Capture the result and clean up before unwrapping or asserting, so
        // a failure cannot leave the variables set for the rest of the test
        // process.
        let config = EdgeeConfig::from_env();

        // SAFETY: same assumption as for the `set_var` calls above.
        unsafe {
            std::env::remove_var("EDGEE_API_KEY");
            std::env::remove_var("EDGEE_BASE_URL");
        }

        let config = config.unwrap();
        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.base_url, "https://test.example.com");
    }

    /// `with_base_url` overrides the base URL and keeps the API key.
    #[test]
    fn test_config_builder() {
        let config = EdgeeConfig::new("my-key").with_base_url("https://custom.example.com");
        assert_eq!(config.api_key, "my-key");
        assert_eq!(config.base_url, "https://custom.example.com");
    }

    /// Each `Message` constructor sets the matching role, plus the content or
    /// tool-call id where checked.
    #[test]
    fn test_message_constructors() {
        let msg = Message::user("hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content.as_deref(), Some("hello"));

        let msg = Message::system("You are helpful");
        assert_eq!(msg.role, Role::System);

        let msg = Message::developer("You are helpful");
        assert_eq!(msg.role, Role::Developer);

        let msg = Message::assistant("Hi there");
        assert_eq!(msg.role, Role::Assistant);

        let msg = Message::tool("call-123", "result");
        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.tool_call_id.as_deref(), Some("call-123"));
    }
}