rig_onchain_kit/http/
routes.rs1use super::middleware::verify_auth;
2use super::state::AppState;
3use crate::common::spawn_with_signer;
4use crate::cross_chain::agent::create_cross_chain_agent;
5use crate::reasoning_loop::LoopResponse;
6use crate::reasoning_loop::ReasoningLoop;
7use crate::signer::privy::PrivySigner;
8use crate::signer::TransactionSigner;
9use actix_web::{
10 get, post, web, Error, HttpRequest, HttpResponse, Responder,
11};
12use actix_web_lab::sse;
13use anyhow::Result;
14use rig::completion::Message;
15use rig::message::UserContent;
16use rig::OneOrMany;
17use serde::{Deserialize, Serialize};
18use serde_json::json;
19use std::sync::Arc;
20use std::time::Duration;
21
22#[derive(Deserialize)]
23pub struct ChatRequest {
24 prompt: String,
25 #[serde(deserialize_with = "deserialize_messages")]
26 chat_history: Vec<Message>,
27 #[serde(default)]
28 chain: Option<String>,
29 #[serde(default)]
30 preamble: Option<String>,
31}
32
33#[derive(Serialize)]
34#[serde(tag = "type", content = "content")]
35pub enum StreamResponse {
36 Message(String),
37 ToolCall { name: String, result: String },
38 Error(String),
39}
40
41#[derive(Serialize)]
42pub enum ServerError {
43 WalletError,
44 PrivyError,
45 ChainNotSupported,
46}
47
48#[post("/stream")]
49async fn stream(
50 req: HttpRequest,
51 state: web::Data<AppState>,
52 request: web::Json<ChatRequest>,
53) -> impl Responder {
54 let user_session = match verify_auth(&req).await {
55 Ok(s) => s,
56 Err(e) => {
57 let (tx, rx) = tokio::sync::mpsc::channel::<sse::Event>(1);
58 let error_event = sse::Event::Data(sse::Data::new(
59 serde_json::to_string(&StreamResponse::Error(format!(
60 "Error: unauthorized: {}",
61 e
62 )))
63 .unwrap(),
64 ));
65 let _ = tx.send(error_event).await;
66 return sse::Sse::from_infallible_receiver(rx);
67 }
68 };
69
70 let (tx, rx) = tokio::sync::mpsc::channel::<sse::Event>(32);
71
72 let preamble = request.preamble.clone();
73
74 let agent = match request.chain.as_deref() {
76 #[cfg(feature = "solana")]
77 Some("solana") => match create_solana_agent(preamble).await {
78 Ok(agent) => Arc::new(agent),
79 Err(e) => {
80 let error_event = sse::Event::Data(sse::Data::new(
81 serde_json::to_string(&StreamResponse::Error(format!(
82 "Failed to create Solana agent: {}",
83 e
84 )))
85 .unwrap(),
86 ));
87 let _ = tx.send(error_event).await;
88 return sse::Sse::from_infallible_receiver(rx);
89 }
90 },
91 #[cfg(feature = "evm")]
92 Some("evm") => match create_evm_agent(preamble).await {
93 Ok(agent) => Arc::new(agent),
94 Err(e) => {
95 let error_event = sse::Event::Data(sse::Data::new(
96 serde_json::to_string(&StreamResponse::Error(format!(
97 "Failed to create EVM agent: {}",
98 e
99 )))
100 .unwrap(),
101 ));
102 let _ = tx.send(error_event).await;
103 return sse::Sse::from_infallible_receiver(rx);
104 }
105 },
106 Some("omni") => match create_cross_chain_agent(preamble).await {
107 Ok(agent) => Arc::new(agent),
108 Err(e) => {
109 let error_event = sse::Event::Data(sse::Data::new(
110 serde_json::to_string(&StreamResponse::Error(format!(
111 "Failed to create cross-chain agent: {}",
112 e
113 )))
114 .unwrap(),
115 ));
116 let _ = tx.send(error_event).await;
117 return sse::Sse::from_infallible_receiver(rx);
118 }
119 },
120 Some(chain) => {
121 let error_event = sse::Event::Data(sse::Data::new(
122 serde_json::to_string(&StreamResponse::Error(format!(
123 "Unsupported chain: {}",
124 chain
125 )))
126 .unwrap(),
127 ));
128 let _ = tx.send(error_event).await;
129 return sse::Sse::from_infallible_receiver(rx);
130 }
131 None => {
132 let error_event = sse::Event::Data(sse::Data::new(
133 serde_json::to_string(&StreamResponse::Error(
134 "Chain parameter is required".to_string(),
135 ))
136 .unwrap(),
137 ));
138 let _ = tx.send(error_event).await;
139 return sse::Sse::from_infallible_receiver(rx);
140 }
141 };
142
143 let prompt = request.prompt.clone();
144 let messages = request.chat_history.clone();
145 println!("prompt: {}", prompt);
146 println!("messages: {:?}", messages);
147
148 let signer: Arc<dyn TransactionSigner> =
149 Arc::new(PrivySigner::new(state.privy.clone(), user_session.clone()));
150
151 spawn_with_signer(signer, || async move {
152 let reasoning_loop = ReasoningLoop::new(agent).with_stdout(false);
153
154 let mut initial_messages = messages;
155 initial_messages.push(Message::User {
156 content: OneOrMany::one(UserContent::text(prompt)),
157 });
158
159 let (internal_tx, mut internal_rx) = tokio::sync::mpsc::channel(32);
161
162 let tx_clone = tx.clone();
164 let send_task = tokio::spawn(async move {
165 while let Some(response) = internal_rx.recv().await {
166 let stream_response = match response {
167 LoopResponse::Message(text) => {
168 StreamResponse::Message(text)
169 }
170 LoopResponse::ToolCall { name, result } => {
171 StreamResponse::ToolCall { name, result }
172 }
173 };
174
175 if tx_clone
176 .send(sse::Event::Data(sse::Data::new(
177 serde_json::to_string(&stream_response).unwrap(),
178 )))
179 .await
180 .is_err()
181 {
182 break;
183 }
184 }
185 });
186
187 let loop_result = reasoning_loop
189 .stream(initial_messages, Some(internal_tx))
190 .await;
191
192 let _ = send_task.await;
194
195 if let Err(e) = loop_result {
197 let _ = tx
198 .send(sse::Event::Data(sse::Data::new(
199 serde_json::to_string(&StreamResponse::Error(
200 e.to_string(),
201 ))
202 .unwrap(),
203 )))
204 .await;
205 }
206
207 Ok(())
208 })
209 .await;
210
211 sse::Sse::from_infallible_receiver(rx)
212 .with_keep_alive(Duration::from_secs(15))
213 .with_retry_duration(Duration::from_secs(10))
214}
215
216#[get("/healthz")]
217async fn healthz() -> Result<HttpResponse, Error> {
218 Ok(HttpResponse::Ok().json(json!({
219 "status": "ok",
220 "timestamp": chrono::Utc::now().to_rfc3339()
221 })))
222}
223
224#[get("/auth")]
225async fn auth(req: HttpRequest) -> Result<HttpResponse, Error> {
226 let user_session = match verify_auth(&req).await {
227 Ok(session) => session,
228 Err(e) => {
229 return Ok(HttpResponse::Unauthorized()
230 .json(json!({ "error": e.to_string() })))
231 }
232 };
233
234 Ok(HttpResponse::Ok().json(json!({
235 "status": "ok",
236 "wallet_address": user_session.wallet_address,
237 })))
238}
239
240fn deserialize_messages<'de, D>(
241 deserializer: D,
242) -> Result<Vec<Message>, D::Error>
243where
244 D: serde::Deserializer<'de>,
245{
246 #[derive(Deserialize)]
247 struct RawMessage {
248 role: String,
249 content: serde_json::Value,
250 }
251
252 let raw_messages: Vec<RawMessage> = Vec::deserialize(deserializer)?;
253
254 raw_messages
255 .into_iter()
256 .map(|raw| {
257 let content = match raw.role.as_str() {
258 "user" => {
259 let content = match raw.content {
260 serde_json::Value::String(s) => {
261 OneOrMany::one(UserContent::Text(s.into()))
262 }
263 _ => {
264 return Err(serde::de::Error::custom(
265 "Invalid user content format",
266 ))
267 }
268 };
269 Message::User { content }
270 }
271 "assistant" => {
272 let content = match raw.content {
273 serde_json::Value::String(s) => {
274 OneOrMany::one(s.into())
275 }
276 _ => {
277 return Err(serde::de::Error::custom(
278 "Invalid assistant content format",
279 ))
280 }
281 };
282 Message::Assistant { content }
283 }
284 _ => return Err(serde::de::Error::custom("Invalid role")),
285 };
286 Ok(content)
287 })
288 .collect()
289}