rig_onchain_kit/http/
routes.rs

1use super::middleware::verify_auth;
2use super::state::AppState;
3use crate::common::spawn_with_signer;
4use crate::cross_chain::agent::create_cross_chain_agent;
5use crate::reasoning_loop::LoopResponse;
6use crate::reasoning_loop::ReasoningLoop;
7use crate::signer::privy::PrivySigner;
8use crate::signer::TransactionSigner;
9use actix_web::{
10    get, post, web, Error, HttpRequest, HttpResponse, Responder,
11};
12use actix_web_lab::sse;
13use anyhow::Result;
14use rig::completion::Message;
15use rig::message::UserContent;
16use rig::OneOrMany;
17use serde::{Deserialize, Serialize};
18use serde_json::json;
19use std::sync::Arc;
20use std::time::Duration;
21
22#[derive(Deserialize)]
23pub struct ChatRequest {
24    prompt: String,
25    #[serde(deserialize_with = "deserialize_messages")]
26    chat_history: Vec<Message>,
27    #[serde(default)]
28    chain: Option<String>,
29    #[serde(default)]
30    preamble: Option<String>,
31}
32
33#[derive(Serialize)]
34#[serde(tag = "type", content = "content")]
35pub enum StreamResponse {
36    Message(String),
37    ToolCall { name: String, result: String },
38    Error(String),
39}
40
41#[derive(Serialize)]
42pub enum ServerError {
43    WalletError,
44    PrivyError,
45    ChainNotSupported,
46}
47
48#[post("/stream")]
49async fn stream(
50    req: HttpRequest,
51    state: web::Data<AppState>,
52    request: web::Json<ChatRequest>,
53) -> impl Responder {
54    let user_session = match verify_auth(&req).await {
55        Ok(s) => s,
56        Err(e) => {
57            let (tx, rx) = tokio::sync::mpsc::channel::<sse::Event>(1);
58            let error_event = sse::Event::Data(sse::Data::new(
59                serde_json::to_string(&StreamResponse::Error(format!(
60                    "Error: unauthorized: {}",
61                    e
62                )))
63                .unwrap(),
64            ));
65            let _ = tx.send(error_event).await;
66            return sse::Sse::from_infallible_receiver(rx);
67        }
68    };
69
70    let (tx, rx) = tokio::sync::mpsc::channel::<sse::Event>(32);
71
72    let preamble = request.preamble.clone();
73
74    // Select the appropriate agent based on the chain parameter and preamble
75    let agent = match request.chain.as_deref() {
76        #[cfg(feature = "solana")]
77        Some("solana") => match create_solana_agent(preamble).await {
78            Ok(agent) => Arc::new(agent),
79            Err(e) => {
80                let error_event = sse::Event::Data(sse::Data::new(
81                    serde_json::to_string(&StreamResponse::Error(format!(
82                        "Failed to create Solana agent: {}",
83                        e
84                    )))
85                    .unwrap(),
86                ));
87                let _ = tx.send(error_event).await;
88                return sse::Sse::from_infallible_receiver(rx);
89            }
90        },
91        #[cfg(feature = "evm")]
92        Some("evm") => match create_evm_agent(preamble).await {
93            Ok(agent) => Arc::new(agent),
94            Err(e) => {
95                let error_event = sse::Event::Data(sse::Data::new(
96                    serde_json::to_string(&StreamResponse::Error(format!(
97                        "Failed to create EVM agent: {}",
98                        e
99                    )))
100                    .unwrap(),
101                ));
102                let _ = tx.send(error_event).await;
103                return sse::Sse::from_infallible_receiver(rx);
104            }
105        },
106        Some("omni") => match create_cross_chain_agent(preamble).await {
107            Ok(agent) => Arc::new(agent),
108            Err(e) => {
109                let error_event = sse::Event::Data(sse::Data::new(
110                    serde_json::to_string(&StreamResponse::Error(format!(
111                        "Failed to create cross-chain agent: {}",
112                        e
113                    )))
114                    .unwrap(),
115                ));
116                let _ = tx.send(error_event).await;
117                return sse::Sse::from_infallible_receiver(rx);
118            }
119        },
120        Some(chain) => {
121            let error_event = sse::Event::Data(sse::Data::new(
122                serde_json::to_string(&StreamResponse::Error(format!(
123                    "Unsupported chain: {}",
124                    chain
125                )))
126                .unwrap(),
127            ));
128            let _ = tx.send(error_event).await;
129            return sse::Sse::from_infallible_receiver(rx);
130        }
131        None => {
132            let error_event = sse::Event::Data(sse::Data::new(
133                serde_json::to_string(&StreamResponse::Error(
134                    "Chain parameter is required".to_string(),
135                ))
136                .unwrap(),
137            ));
138            let _ = tx.send(error_event).await;
139            return sse::Sse::from_infallible_receiver(rx);
140        }
141    };
142
143    let prompt = request.prompt.clone();
144    let messages = request.chat_history.clone();
145    println!("prompt: {}", prompt);
146    println!("messages: {:?}", messages);
147
148    let signer: Arc<dyn TransactionSigner> =
149        Arc::new(PrivySigner::new(state.privy.clone(), user_session.clone()));
150
151    spawn_with_signer(signer, || async move {
152        let reasoning_loop = ReasoningLoop::new(agent).with_stdout(false);
153
154        let mut initial_messages = messages;
155        initial_messages.push(Message::User {
156            content: OneOrMany::one(UserContent::text(prompt)),
157        });
158
159        // Create a channel for the reasoning loop to send responses
160        let (internal_tx, mut internal_rx) = tokio::sync::mpsc::channel(32);
161
162        // Create a separate task to handle sending responses
163        let tx_clone = tx.clone();
164        let send_task = tokio::spawn(async move {
165            while let Some(response) = internal_rx.recv().await {
166                let stream_response = match response {
167                    LoopResponse::Message(text) => {
168                        StreamResponse::Message(text)
169                    }
170                    LoopResponse::ToolCall { name, result } => {
171                        StreamResponse::ToolCall { name, result }
172                    }
173                };
174
175                if tx_clone
176                    .send(sse::Event::Data(sse::Data::new(
177                        serde_json::to_string(&stream_response).unwrap(),
178                    )))
179                    .await
180                    .is_err()
181                {
182                    break;
183                }
184            }
185        });
186
187        // Run the reasoning loop in the current task (with signer context)
188        let loop_result = reasoning_loop
189            .stream(initial_messages, Some(internal_tx))
190            .await;
191
192        // Wait for the send task to complete
193        let _ = send_task.await;
194
195        // Check if the reasoning loop completed successfully
196        if let Err(e) = loop_result {
197            let _ = tx
198                .send(sse::Event::Data(sse::Data::new(
199                    serde_json::to_string(&StreamResponse::Error(
200                        e.to_string(),
201                    ))
202                    .unwrap(),
203                )))
204                .await;
205        }
206
207        Ok(())
208    })
209    .await;
210
211    sse::Sse::from_infallible_receiver(rx)
212        .with_keep_alive(Duration::from_secs(15))
213        .with_retry_duration(Duration::from_secs(10))
214}
215
216#[get("/healthz")]
217async fn healthz() -> Result<HttpResponse, Error> {
218    Ok(HttpResponse::Ok().json(json!({
219        "status": "ok",
220        "timestamp": chrono::Utc::now().to_rfc3339()
221    })))
222}
223
224#[get("/auth")]
225async fn auth(req: HttpRequest) -> Result<HttpResponse, Error> {
226    let user_session = match verify_auth(&req).await {
227        Ok(session) => session,
228        Err(e) => {
229            return Ok(HttpResponse::Unauthorized()
230                .json(json!({ "error": e.to_string() })))
231        }
232    };
233
234    Ok(HttpResponse::Ok().json(json!({
235        "status": "ok",
236        "wallet_address": user_session.wallet_address,
237    })))
238}
239
240fn deserialize_messages<'de, D>(
241    deserializer: D,
242) -> Result<Vec<Message>, D::Error>
243where
244    D: serde::Deserializer<'de>,
245{
246    #[derive(Deserialize)]
247    struct RawMessage {
248        role: String,
249        content: serde_json::Value,
250    }
251
252    let raw_messages: Vec<RawMessage> = Vec::deserialize(deserializer)?;
253
254    raw_messages
255        .into_iter()
256        .map(|raw| {
257            let content = match raw.role.as_str() {
258                "user" => {
259                    let content = match raw.content {
260                        serde_json::Value::String(s) => {
261                            OneOrMany::one(UserContent::Text(s.into()))
262                        }
263                        _ => {
264                            return Err(serde::de::Error::custom(
265                                "Invalid user content format",
266                            ))
267                        }
268                    };
269                    Message::User { content }
270                }
271                "assistant" => {
272                    let content = match raw.content {
273                        serde_json::Value::String(s) => {
274                            OneOrMany::one(s.into())
275                        }
276                        _ => {
277                            return Err(serde::de::Error::custom(
278                                "Invalid assistant content format",
279                            ))
280                        }
281                    };
282                    Message::Assistant { content }
283                }
284                _ => return Err(serde::de::Error::custom("Invalid role")),
285            };
286            Ok(content)
287        })
288        .collect()
289}