1use async_openai::types::{CreateChatCompletionRequest, CreateChatCompletionStreamResponse};
2use futures::Stream;
3use serde::{Deserialize, Serialize};
4use std::{net::SocketAddr, pin::Pin};
5
6use crate::transport::{NodeId, RpcClient};
7
8#[derive(Serialize, Deserialize, Debug, Clone)]
9pub struct TaskContext {
10 pub node_id: NodeId,
11 pub proxy_addr: Option<SocketAddr>,
12 pub model_ctx: ModelContext,
13}
14
15#[derive(Serialize, Deserialize, Debug, Clone)]
16pub struct ModelContext {
17 pub api_base_url: Option<String>,
19
20 pub credentials_env_name: Option<String>,
22}
23
24#[derive(Debug, Deserialize)]
25pub enum Task {
26 Text {
27 node_id: Option<NodeId>,
28 request: CreateChatCompletionRequest,
29 },
30 Image {
31 node_id: Option<NodeId>,
32 prompt: String,
33 },
34}
35
36#[derive(Debug, Serialize)]
37pub enum TaskResponse {
38 Text {
39 node_id: Option<NodeId>,
40 },
41 Image {
42 node_id: Option<NodeId>,
43 data: String,
44 },
45}
46
47#[derive(thiserror::Error, Debug)]
48pub enum TaskError {
49 #[error("text task failed: {0}")]
50 Text(String),
51 #[error("image task failed: {0}")]
52 Image(String),
53 #[error("Network error: {0}")]
54 Network(String),
55}
56
/// Entry point for executing tasks on remote nodes over RPC.
///
/// Holds no state of its own: every call builds what it needs from the
/// context passed in, so instances are free to clone and default-construct.
#[derive(Clone, Default)]
pub struct Runner {}
61
62impl Runner {
63 pub async fn start_text_stream(
64 &self,
65 ctx: TaskContext,
66 request: CreateChatCompletionRequest,
67 ) -> Result<
68 (
69 TaskResponse,
70 Pin<
71 Box<
72 dyn Stream<Item = Result<CreateChatCompletionStreamResponse, anyhow::Error>>
73 + Send,
74 >,
75 >,
76 ),
77 TaskError,
78 > {
79 eprintln!("starting text stream for ctx: {:?}", ctx);
80
81 let mut client = RpcClient::new(ctx.node_id.clone(), ctx.proxy_addr)
82 .await
83 .map_err(|e| TaskError::Network(e.to_string()))?;
84
85 let stream = client
86 .compute_text(request, ctx.clone())
87 .await
88 .map_err(|e| TaskError::Text(e.to_string()))?;
89
90 Ok((
91 TaskResponse::Text {
92 node_id: Some(ctx.node_id),
93 },
94 stream,
95 ))
96 }
97
98 pub async fn run(&self, ctx: TaskContext, task: Task) -> Result<TaskResponse, TaskError> {
99 eprintln!("running task with context: {:?}", ctx);
100 let mut client = RpcClient::new(ctx.node_id.clone(), ctx.proxy_addr)
101 .await
102 .map_err(|e| TaskError::Network(e.to_string()))?;
103
104 match task {
105 Task::Image { node_id, prompt } => {
106 let node_id = node_id.clone();
107 let data = client
108 .generate_image(prompt, ctx)
109 .await
110 .map_err(|e| TaskError::Image(e.to_string()))?;
111
112 Ok(TaskResponse::Image { node_id, data })
113 }
114 Task::Text { .. } => Err(TaskError::Text(
115 "Use start_text_stream for text completion tasks".into(),
116 )),
117 }
118 }
119}