use base64;
use futures_util::{SinkExt, StreamExt};
use orign::models::{
ChatRequest, ChatResponse, ContentItem, ImageUrlContent, MessageContent, MessageItem,
ModelReadyResponse, Prompt, TokenResponse,
};
use base64::prelude::*;
use http::header::HeaderValue;
use orign::config::GlobalConfig;
use serde::Deserialize;
use std::error::Error;
use tokio::io::AsyncWriteExt;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::protocol::Message;
/// One JSON message received as a text frame on the chat stream WebSocket.
///
/// Because of `#[serde(untagged)]` there is no discriminator field: serde tries each
/// variant in declaration order and keeps the first one that deserializes successfully.
/// The variant order is therefore significant — if a payload could match more than one
/// response type, the earlier variant wins. Do not reorder without checking the shapes
/// of the model types (NOTE(review): their field sets are defined in `orign::models`,
/// not visible here).
#[derive(Deserialize)]
#[serde(untagged)]
enum WebSocketResponse {
    /// Model readiness status; `execute` aborts with an error when `ready` is false.
    ModelReady(ModelReadyResponse),
    /// Chat completion chunk; `execute` streams the first choice's text and stops once
    /// that choice carries a `finish_reason`.
    ChatResponse(ChatResponse),
    /// Batch of tokens; `execute` writes each token to stdout as it arrives.
    TokenResponse(TokenResponse),
}
pub async fn execute(
model: String,
message: String,
images: Vec<String>,
adapter: Option<String>,
provider: Option<String>,
) -> Result<(), Box<dyn Error>> {
let config = GlobalConfig::read()?;
let server = config.server.unwrap();
let api_key = config.api_key.as_deref().ok_or("API key not set")?;
let bearer_token = format!("Bearer {}", api_key);
let mut websocket_url = if server.starts_with("https://") {
format!(
"wss://{}/v1/chat/stream?model={}",
server.strip_prefix("https://").unwrap(),
urlencoding::encode(&model)
)
} else if server.starts_with("http://") {
format!(
"ws://{}/v1/chat/stream?model={}",
server.strip_prefix("http://").unwrap(),
urlencoding::encode(&model)
)
} else {
return Err("Invalid server URL scheme. Must be http:// or https://".into());
};
if let Some(provider_name) = provider {
websocket_url.push_str(&format!("&provider={}", provider_name));
}
let mut request = websocket_url.as_str().into_client_request()?;
request
.headers_mut()
.append("Authorization", HeaderValue::from_str(&bearer_token)?);
let (ws_stream, _response) = connect_async(request).await?;
let (mut write, mut read) = ws_stream.split();
let mut content_items = vec![ContentItem {
content_type: "text".to_string(),
text: Some(message),
image_url: None,
}];
for image_path_or_url in images {
let image_data: Vec<u8>;
if image_path_or_url.starts_with("http://") || image_path_or_url.starts_with("https://") {
let response = reqwest::get(&image_path_or_url).await?;
let bytes = response.bytes().await?;
image_data = bytes.to_vec();
} else {
image_data = tokio::fs::read(&image_path_or_url).await?;
}
let mime_type = infer::get(&image_data)
.map(|t| t.mime_type())
.unwrap_or("application/octet-stream");
let encoded_image = BASE64_STANDARD.encode(&image_data);
let data_url = format!("data:{};base64,{}", mime_type, encoded_image);
content_items.push(ContentItem {
content_type: "image_url".to_string(),
text: None,
image_url: Some(ImageUrlContent { url: data_url }),
});
}
let messages = vec![MessageItem {
role: "user".to_string(),
content: MessageContent::Item(content_items),
}];
let payload = ChatRequest {
model: Some(model.clone()),
prompt: Some(Prompt { messages }),
adapter: adapter.clone(),
stream: true,
..Default::default()
};
let payload_text = serde_json::to_string(&payload)?;
write.send(Message::Text(payload_text.clone())).await?;
let stdout = tokio::io::stdout();
let mut accumulated_text = String::new();
let mut handle = stdout;
while let Some(msg) = read.next().await {
match msg {
Ok(Message::Text(text)) => {
let json_msg: WebSocketResponse = match serde_json::from_str(&text) {
Ok(msg) => msg,
Err(e) => {
eprintln!("Failed to parse message: {}", e);
eprintln!("Raw message: {}", text);
continue;
}
};
match json_msg {
WebSocketResponse::ModelReady(ready_response) => {
if !ready_response.ready {
if let Some(error) = ready_response.error {
return Err(format!("Model not ready: {}", error).into());
}
return Err("Model not ready".into());
}
}
WebSocketResponse::ChatResponse(chat_response) => {
if let Some(choice) = chat_response.choices.get(0) {
accumulated_text.push_str(&choice.text);
handle.write_all(choice.text.as_bytes()).await?;
handle.flush().await?;
if choice.finish_reason.is_some() {
break;
}
}
}
WebSocketResponse::TokenResponse(token_response) => {
for token in token_response.tokens {
accumulated_text.push_str(&token);
handle.write_all(token.as_bytes()).await?;
handle.flush().await?;
}
}
}
}
Ok(Message::Binary(bin)) => {
println!("Received binary message of {} bytes", bin.len());
tokio::io::stdout().write_all(&bin).await.unwrap();
}
Ok(Message::Close(frame)) => {
println!("Connection closed by server");
if let Some(cf) = frame {
println!("Close frame: code: {}, reason: {}", cf.code, cf.reason);
}
break;
}
Ok(Message::Ping(_)) => {
println!("Received ping");
}
Ok(Message::Pong(_)) => {
println!("Received pong");
}
Err(e) => {
eprintln!("WebSocket error: {}", e);
eprintln!("Error details: {:?}", e);
break;
}
_ => {}
}
}
handle.write_all(b"\n").await?;
handle.flush().await?;
Ok(())
}