// myst_client/task.rs

1use async_openai::types::{CreateChatCompletionRequest, CreateChatCompletionStreamResponse};
2use futures::Stream;
3use serde::{Deserialize, Serialize};
4use std::{net::SocketAddr, pin::Pin};
5
6use crate::transport::{NodeId, RpcClient};
7
/// Everything a task needs to reach and run against a compute node.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TaskContext {
    /// Identifier of the node that will execute the task (used to build the `RpcClient`).
    pub node_id: NodeId,
    /// Optional proxy address the RPC connection is routed through.
    pub proxy_addr: Option<SocketAddr>,
    /// Model-specific connection details; carried along to the node with each task.
    pub model_ctx: ModelContext,
}
14
/// Connection details for the model a task targets.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModelContext {
    /// The base URL for the API associated with this model.
    pub api_base_url: Option<String>,

    /// The name of the environment variable containing the credentials for this model.
    pub credentials_env_name: Option<String>,
}
23
/// A unit of work that can be dispatched to a node.
///
/// NOTE(review): each variant carries its own optional `node_id`, but node
/// selection in `Runner` is driven by `TaskContext::node_id`; the variant's
/// `node_id` is only echoed back in the response — confirm intent.
#[derive(Debug, Deserialize)]
pub enum Task {
    /// A chat-completion request. `Runner::run` rejects this variant;
    /// use `Runner::start_text_stream` instead.
    Text {
        node_id: Option<NodeId>,
        request: CreateChatCompletionRequest,
    },
    /// Image generation from a text prompt.
    Image {
        node_id: Option<NodeId>,
        prompt: String,
    },
}
35
/// Result payload returned for a task.
#[derive(Debug, Serialize)]
pub enum TaskResponse {
    /// Acknowledgement for a started text stream; the actual completion
    /// chunks arrive via the stream returned alongside this response.
    Text {
        node_id: Option<NodeId>,
    },
    /// A generated image, encoded as a string (presumably base64 —
    /// confirm against `RpcClient::generate_image`).
    Image {
        node_id: Option<NodeId>,
        data: String,
    },
}
46
47#[derive(thiserror::Error, Debug)]
48pub enum TaskError {
49    #[error("text task failed: {0}")]
50    Text(String),
51    #[error("image task failed: {0}")]
52    Image(String),
53    #[error("Network error: {0}")]
54    Network(String),
55}
56
/// Executes `Task`s against remote nodes via `RpcClient`.
///
/// Currently stateless; `Default` gives the canonical empty runner.
#[derive(Default, Clone)]
pub struct Runner {
    // add configurations for caching, workflows, etc.
}
61
62impl Runner {
63    pub async fn start_text_stream(
64        &self,
65        ctx: TaskContext,
66        request: CreateChatCompletionRequest,
67    ) -> Result<
68        (
69            TaskResponse,
70            Pin<
71                Box<
72                    dyn Stream<Item = Result<CreateChatCompletionStreamResponse, anyhow::Error>>
73                        + Send,
74                >,
75            >,
76        ),
77        TaskError,
78    > {
79        eprintln!("starting text stream for ctx: {:?}", ctx);
80
81        let mut client = RpcClient::new(ctx.node_id.clone(), ctx.proxy_addr)
82            .await
83            .map_err(|e| TaskError::Network(e.to_string()))?;
84
85        let stream = client
86            .compute_text(request, ctx.clone())
87            .await
88            .map_err(|e| TaskError::Text(e.to_string()))?;
89
90        Ok((
91            TaskResponse::Text {
92                node_id: Some(ctx.node_id),
93            },
94            stream,
95        ))
96    }
97
98    pub async fn run(&self, ctx: TaskContext, task: Task) -> Result<TaskResponse, TaskError> {
99        eprintln!("running task with context: {:?}", ctx);
100        let mut client = RpcClient::new(ctx.node_id.clone(), ctx.proxy_addr)
101            .await
102            .map_err(|e| TaskError::Network(e.to_string()))?;
103
104        match task {
105            Task::Image { node_id, prompt } => {
106                let node_id = node_id.clone();
107                let data = client
108                    .generate_image(prompt, ctx)
109                    .await
110                    .map_err(|e| TaskError::Image(e.to_string()))?;
111
112                Ok(TaskResponse::Image { node_id, data })
113            }
114            Task::Text { .. } => Err(TaskError::Text(
115                "Use start_text_stream for text completion tasks".into(),
116            )),
117        }
118    }
119}