orign 0.2.3

A globally distributed container orchestrator
Documentation
use base64;
use futures_util::{SinkExt, StreamExt};
use orign::models::{
    ChatRequest, ChatResponse, ContentItem, ImageUrlContent, MessageContent, MessageItem,
    ModelReadyResponse, Prompt, TokenResponse,
};

use base64::prelude::*;
use http::header::HeaderValue;
use orign::config::GlobalConfig;
use serde::Deserialize;
use std::error::Error;
use tokio::io::AsyncWriteExt;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::protocol::Message;

/// A single JSON text frame received from the streaming chat WebSocket.
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order and
/// keeps the first one that deserializes successfully. The order below is
/// therefore significant: a frame whose shape matches more than one response
/// type is decoded as the earliest matching variant.
#[derive(Deserialize)]
#[serde(untagged)]
enum WebSocketResponse {
    /// Readiness report for the requested model; a not-ready report aborts the chat.
    ModelReady(ModelReadyResponse),
    /// A chat completion chunk; the text of its first choice is printed.
    ChatResponse(ChatResponse),
    /// A batch of generated tokens; each token is printed in order.
    TokenResponse(TokenResponse),
}

pub async fn execute(
    model: String,
    message: String,
    images: Vec<String>,
    adapter: Option<String>,
    provider: Option<String>,
) -> Result<(), Box<dyn Error>> {
    let config = GlobalConfig::read()?;
    let server = config.server.unwrap();

    let api_key = config.api_key.as_deref().ok_or("API key not set")?;
    let bearer_token = format!("Bearer {}", api_key);

    // Prepare the WebSocket URL
    let mut websocket_url = if server.starts_with("https://") {
        format!(
            "wss://{}/v1/chat/stream?model={}",
            server.strip_prefix("https://").unwrap(),
            urlencoding::encode(&model)
        )
    } else if server.starts_with("http://") {
        format!(
            "ws://{}/v1/chat/stream?model={}",
            server.strip_prefix("http://").unwrap(),
            urlencoding::encode(&model)
        )
    } else {
        return Err("Invalid server URL scheme. Must be http:// or https://".into());
    };

    if let Some(provider_name) = provider {
        websocket_url.push_str(&format!("&provider={}", provider_name));
    }

    let mut request = websocket_url.as_str().into_client_request()?;
    request
        .headers_mut()
        .append("Authorization", HeaderValue::from_str(&bearer_token)?);

    // Connect to the WebSocket server
    let (ws_stream, _response) = connect_async(request).await?;

    let (mut write, mut read) = ws_stream.split();

    // Prepare the content items
    let mut content_items = vec![ContentItem {
        content_type: "text".to_string(),
        text: Some(message),
        image_url: None,
    }];

    // Handle image if supplied
    for image_path_or_url in images {
        let image_data: Vec<u8>;

        // Determine if it's a URL
        if image_path_or_url.starts_with("http://") || image_path_or_url.starts_with("https://") {
            // It's a URL, download the image
            let response = reqwest::get(&image_path_or_url).await?;
            let bytes = response.bytes().await?;
            image_data = bytes.to_vec();
        } else {
            // It's a local file path, read the image
            image_data = tokio::fs::read(&image_path_or_url).await?;
        }

        // Detect MIME type of the image
        let mime_type = infer::get(&image_data)
            .map(|t| t.mime_type())
            .unwrap_or("application/octet-stream");

        // Base64 encode the image data and create a data URL
        let encoded_image = BASE64_STANDARD.encode(&image_data);
        let data_url = format!("data:{};base64,{}", mime_type, encoded_image);

        // Add the image to the content items
        content_items.push(ContentItem {
            content_type: "image_url".to_string(),
            text: None,
            image_url: Some(ImageUrlContent { url: data_url }),
        });
    }

    // Prepare the messages vector
    let messages = vec![MessageItem {
        role: "user".to_string(),
        content: MessageContent::Item(content_items),
    }];

    let payload = ChatRequest {
        model: Some(model.clone()),
        prompt: Some(Prompt { messages }),
        adapter: adapter.clone(),
        stream: true,
        ..Default::default()
    };
    let payload_text = serde_json::to_string(&payload)?;

    // Send the initial chat message
    write.send(Message::Text(payload_text.clone())).await?;
    let stdout = tokio::io::stdout();
    // println!("Initial message sent: {}", payload_text);

    let mut accumulated_text = String::new();
    let mut handle = stdout;

    // Handle messages from WebSocket
    while let Some(msg) = read.next().await {
        match msg {
            Ok(Message::Text(text)) => {
                let json_msg: WebSocketResponse = match serde_json::from_str(&text) {
                    Ok(msg) => msg,
                    Err(e) => {
                        eprintln!("Failed to parse message: {}", e);
                        eprintln!("Raw message: {}", text);
                        continue;
                    }
                };

                match json_msg {
                    WebSocketResponse::ModelReady(ready_response) => {
                        if !ready_response.ready {
                            if let Some(error) = ready_response.error {
                                return Err(format!("Model not ready: {}", error).into());
                            }
                            return Err("Model not ready".into());
                        }
                        // Model is ready, continue with the chat
                    }
                    WebSocketResponse::ChatResponse(chat_response) => {
                        if let Some(choice) = chat_response.choices.get(0) {
                            accumulated_text.push_str(&choice.text);
                            handle.write_all(choice.text.as_bytes()).await?;
                            handle.flush().await?;

                            if choice.finish_reason.is_some() {
                                break;
                            }
                        }
                    }
                    WebSocketResponse::TokenResponse(token_response) => {
                        for token in token_response.tokens {
                            accumulated_text.push_str(&token);
                            handle.write_all(token.as_bytes()).await?;
                            handle.flush().await?;
                        }
                    }
                }
            }
            Ok(Message::Binary(bin)) => {
                println!("Received binary message of {} bytes", bin.len());
                tokio::io::stdout().write_all(&bin).await.unwrap();
            }
            Ok(Message::Close(frame)) => {
                println!("Connection closed by server");
                if let Some(cf) = frame {
                    println!("Close frame: code: {}, reason: {}", cf.code, cf.reason);
                }
                break;
            }
            Ok(Message::Ping(_)) => {
                println!("Received ping");
            }
            Ok(Message::Pong(_)) => {
                println!("Received pong");
            }
            Err(e) => {
                eprintln!("WebSocket error: {}", e);
                eprintln!("Error details: {:?}", e);
                break;
            }
            _ => {}
        }
    }
    // Write a newline character at the end
    handle.write_all(b"\n").await?;
    handle.flush().await?;
    Ok(())
}