Skip to main content

potato/utils/
ai.rs

1pub struct OpenAISender {
2    id: String,
3    object: String,
4    model: String,
5    role: String,
6    tx: tokio::sync::mpsc::Sender<Vec<u8>>,
7}
8
9impl OpenAISender {
10    pub async fn new(
11        id: impl Into<String>,
12        object: impl Into<String>,
13        model: impl Into<String>,
14        role: impl Into<String>,
15        buffer_size: usize,
16    ) -> anyhow::Result<(Self, crate::HttpResponse)> {
17        let (tx, rx) = tokio::sync::mpsc::channel(buffer_size);
18        let obj = Self {
19            id: id.into(),
20            object: object.into(),
21            model: model.into(),
22            role: role.into(),
23            tx,
24        };
25
26        let root = serde_json::to_string(&serde_json::json!({
27            "id": obj.id,
28            "object": obj.object,
29            "created": chrono::Utc::now().timestamp(),
30            "model": obj.model,
31            "choices": [{
32                "index": 0,
33                "delta": {
34                    "role": obj.role,
35                },
36                "finish_reason": null,
37            }]
38        }))?;
39        let payload = format!("data: {root}\n\n");
40        obj.tx.send(payload.into_bytes()).await?;
41
42        Ok((obj, crate::HttpResponse::sse(rx)))
43    }
44
45    pub async fn send(&self, message: impl Into<String>) -> anyhow::Result<()> {
46        let root = serde_json::to_string(&serde_json::json!({
47            "id": self.id,
48            "object": self.object,
49            "created": chrono::Utc::now().timestamp(),
50            "model": self.model,
51            "choices": [{
52                "index": 0,
53                "delta": {
54                    "content": message.into(),
55                },
56                "finish_reason": null,
57            }]
58        }))?;
59        let payload = format!("data: {root}\n\n");
60        self.tx.send(payload.into_bytes()).await?;
61        Ok(())
62    }
63
64    pub async fn send_finish(&self, finish_reason: impl Into<String>) -> anyhow::Result<()> {
65        let root = serde_json::to_string(&serde_json::json!({
66            "id": self.id,
67            "object": self.object,
68            "created": chrono::Utc::now().timestamp(),
69            "model": self.model,
70            "choices": [{
71                "index": 0,
72                "delta": {},
73                "finish_reason": finish_reason.into(),
74            }]
75        }))?;
76        let payload = format!("data: {}\n\n", serde_json::to_string(&root)?);
77        self.tx.send(payload.into_bytes()).await?;
78        self.tx.send(b"data: [DONE]\n\n".to_vec()).await?;
79        Ok(())
80    }
81}
82
83pub struct ClaudeSender {
84    tx: tokio::sync::mpsc::Sender<Vec<u8>>,
85}
86
87impl ClaudeSender {
88    pub async fn new(
89        id: impl Into<String>,
90        model: impl Into<String>,
91        role: impl Into<String>,
92        buffer_size: usize,
93    ) -> anyhow::Result<(Self, crate::HttpResponse)> {
94        let (tx, rx) = tokio::sync::mpsc::channel(buffer_size);
95        let root = serde_json::to_string(&serde_json::json!({
96            "type": "message_start",
97            "message": {
98                "id": id.into(),
99                "type": "message",
100                "role": role.into(),
101                "model": model.into(),
102                "content": [],
103                "stop_reason": null,
104                "stop_sequence": null,
105                "usage": {
106                    "input_tokens": 0,
107                    "output_tokens": 0
108                }
109            }
110        }))?;
111        let payload = format!("event: message_start\ndata: {root}\n\n");
112        tx.send(payload.into_bytes()).await?;
113
114        let root = serde_json::to_string(&serde_json::json!({
115            "type": "content_block_start",
116            "index": 0,
117            "content_block": {
118                "type": "text",
119                "text": ""
120            }
121        }))?;
122        let payload = format!("event: content_block_start\ndata: {root}\n\n");
123        tx.send(payload.into_bytes()).await?;
124
125        Ok((Self { tx }, crate::HttpResponse::sse(rx)))
126    }
127
128    pub async fn send(&self, message: impl Into<String>) -> anyhow::Result<()> {
129        let root = serde_json::to_string(&serde_json::json!({
130            "type": "content_block_delta",
131            "index": 0,
132            "delta": {
133                "text": message.into()
134            }
135        }))?;
136        let payload = format!("event: content_block_delta\ndata: {root}\n\n");
137        self.tx.send(payload.into_bytes()).await?;
138        Ok(())
139    }
140
141    pub async fn send_finish(&self) -> anyhow::Result<()> {
142        let root = serde_json::to_string(&serde_json::json!({
143            "type": "content_block_stop",
144            "index": 0
145        }))?;
146        let payload = format!("event: content_block_stop\ndata: {root}\n\n");
147        self.tx.send(payload.into_bytes()).await?;
148
149        let root = serde_json::to_string(&serde_json::json!({
150            "type": "message_delta",
151            "delta": {
152                "stop_reason": "end_turn",
153                "stop_sequence": null
154            },
155            "usage": {
156                "output_tokens": 0
157            }
158        }))?;
159        let payload = format!("event: message_delta\ndata: {root}\n\n");
160        self.tx.send(payload.into_bytes()).await?;
161
162        let root = serde_json::to_string(&serde_json::json!({
163            "type": "message_stop"
164        }))?;
165        let payload = format!("event: message_stop\ndata: {root}\n\n");
166        self.tx.send(payload.into_bytes()).await?;
167
168        Ok(())
169    }
170}
171
172pub struct OllamaSender {
173    model: String,
174    tx: tokio::sync::mpsc::Sender<Vec<u8>>,
175}
176
177impl OllamaSender {
178    pub async fn new(
179        model: impl Into<String>,
180        buffer_size: usize,
181    ) -> anyhow::Result<(Self, crate::HttpResponse)> {
182        let (tx, rx) = tokio::sync::mpsc::channel(buffer_size);
183        let model = model.into();
184
185        let mut resp = crate::HttpResponse::sse(rx);
186        resp.add_header("Content-Type".into(), "application/x-ndjson".into());
187
188        Ok((Self { model, tx }, resp))
189    }
190
191    pub async fn send(&self, message: impl Into<String>) -> anyhow::Result<()> {
192        let root = serde_json::to_string(&serde_json::json!({
193            "model": self.model,
194            "created_at": chrono::Utc::now().to_rfc3339(),
195            "response": message.into(),
196            "done": false
197        }))?;
198        // Ollama 使用 NDJSON 格式(newline-delimited JSON)
199        let payload = format!("{root}\n");
200        self.tx.send(payload.into_bytes()).await?;
201        Ok(())
202    }
203
204    pub async fn send_finish(&self) -> anyhow::Result<()> {
205        let root = serde_json::to_string(&serde_json::json!({
206            "model": self.model,
207            "created_at": chrono::Utc::now().to_rfc3339(),
208            "response": "",
209            "done": true,
210            "done_reason": "stop"
211        }))?;
212        let payload = format!("{root}\n");
213        self.tx.send(payload.into_bytes()).await?;
214        Ok(())
215    }
216}