1use crate::types::{CCIPReadHandler, RPCCall, RPCResponse};
2use crate::CCIPReadMiddlewareError;
3use axum::{
4 extract::{Path, State},
5 http::StatusCode,
6 response::IntoResponse,
7 routing::{get, post},
8 Json, Router,
9};
10use ethers_core::abi::{Abi, Function};
11use ethers_core::utils::hex;
12use serde::Deserialize;
13use serde_json::{json, Value};
14use std::collections::HashMap;
15use std::net::{IpAddr, SocketAddr};
16use std::str::FromStr;
17use std::sync::Arc;
18use tower_http::trace::TraceLayer;
19use tracing::debug;
20
21type Handlers = HashMap<[u8; 4], (Function, Arc<dyn CCIPReadHandler + Sync + Send>)>;
22
23struct AppState {
24 handlers: Handlers,
25}
26
27#[derive(Clone)]
29pub struct Server {
30 ip_address: IpAddr,
31 port: u16,
32 handlers: Handlers,
33}
34
35#[derive(Deserialize)]
36pub struct CCIPReadMiddlewareRequest {
37 sender: String,
38 calldata: String,
39}
40
41impl Server {
42 pub fn new(ip_address: IpAddr, port: u16) -> Self {
48 Server {
49 ip_address,
50 port,
51 handlers: HashMap::new(),
52 }
53 }
54
55 pub fn add(
61 &mut self,
62 abi: Abi,
63 name: &str,
64 callback: Arc<dyn CCIPReadHandler + Sync + Send>,
65 ) -> Result<(), CCIPReadMiddlewareError> {
66 let function = abi.function(name)?.clone();
67 debug!(
68 "Added function with short sig: {:?}",
69 function.short_signature()
70 );
71 self.handlers
72 .insert(function.short_signature(), (function, callback));
73 Ok(())
74 }
75
76 pub async fn start(&self, router: Option<Router>) -> Result<(), CCIPReadMiddlewareError> {
81 let ccip_router = self.router();
82 let app: Router = if let Some(router) = router {
83 router.merge(ccip_router)
84 } else {
85 ccip_router
86 };
87
88 let bound_interface: SocketAddr = SocketAddr::new(self.ip_address, self.port);
89 let _ = axum::Server::bind(&bound_interface)
90 .serve(app.into_make_service())
91 .await;
92 Ok(())
93 }
94
95 fn router(&self) -> Router {
96 let shared_state = Arc::new(AppState {
97 handlers: self.handlers.clone(),
98 });
99 Router::new()
100 .route("/gateway/:sender/:calldata", get(gateway_get))
101 .route("/gateway", post(gateway_post))
102 .with_state(shared_state)
103 .layer(TraceLayer::new_for_http())
104 }
105}
106
107async fn gateway_get(
108 Path((sender, calldata)): Path<(String, String)>,
109 State(app_state): State<Arc<AppState>>,
110) -> Result<impl IntoResponse, StatusCode> {
111 let calldata = String::from(calldata.strip_suffix(".json").unwrap_or(calldata.as_str()));
112 debug!("Should handle sender={:?} calldata={}", sender, calldata);
113
114 if let Ok(calldata) = ethers_core::types::Bytes::from_str(&calldata.as_str()[2..]) {
115 let response = call(
116 RPCCall {
117 to: sender.clone(),
118 data: calldata,
119 },
120 app_state.handlers.clone(),
121 )
122 .await
123 .unwrap();
124
125 let body = response.body;
126 Ok((StatusCode::OK, Json(body)))
127 } else {
128 let error_message: Value = json!({
129 "message": "Unexpected error",
130 });
131 Ok((StatusCode::INTERNAL_SERVER_ERROR, Json(error_message)))
132 }
133}
134
135async fn gateway_post(
136 State(app_state): State<Arc<AppState>>,
137 Json(data): Json<CCIPReadMiddlewareRequest>,
138) -> Result<impl IntoResponse, StatusCode> {
139 let sender = data.sender;
140 let calldata = String::from(
141 data.calldata
142 .strip_suffix(".json")
143 .unwrap_or(data.calldata.as_str()),
144 );
145 debug!("Should handle sender={:?} calldata={}", sender, calldata);
146
147 if let Ok(calldata) = ethers_core::types::Bytes::from_str(&calldata.as_str()[2..]) {
148 let response = call(
149 RPCCall {
150 to: sender.clone(),
151 data: calldata,
152 },
153 app_state.handlers.clone(),
154 )
155 .await
156 .unwrap();
157
158 let body = response.body;
159 Ok((StatusCode::OK, Json(body)))
160 } else {
161 let error_message: Value = json!({
162 "message": "Unexpected error",
163 });
164 Ok((StatusCode::INTERNAL_SERVER_ERROR, Json(error_message)))
165 }
166}
167
168#[tracing::instrument(
169 name = "ccip_server"
170 skip_all
171)]
172async fn call(call: RPCCall, handlers: Handlers) -> Result<RPCResponse, CCIPReadMiddlewareError> {
173 debug!("Received call with {:?}", call);
174 let selector = &call.data[0..4];
175
176 let handler = if let Some(handler) = handlers.get(selector) {
178 handler
179 } else {
180 return Ok(RPCResponse {
181 status: 404,
182 body: json!({
183 "message": format!("No implementation for function with selector 0x{}", hex::encode(selector)),
184 }),
185 });
186 };
187
188 let args = handler.0.decode_input(&call.data[4..])?;
190
191 let callback = handler.1.clone();
192 if let Ok(tokens) = callback
193 .call(
194 args,
195 RPCCall {
196 to: call.to,
197 data: call.data,
198 },
199 )
200 .await
201 {
202 let encoded_data = ethers_core::abi::encode(&tokens);
203 let encoded_data = format!("0x{}", hex::encode(encoded_data));
204 debug!("Final encoded data: {}", encoded_data);
205
206 Ok(RPCResponse {
207 status: 200,
208 body: json!({
209 "data": encoded_data,
210 }),
211 })
212 } else {
213 Ok(RPCResponse {
214 status: 500,
215 body: json!({
216 "message": "Unexpected error",
217 }),
218 })
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use ethers::abi::AbiParser;
    use ethers::contract::BaseContract;
    use serde_json::{json, Value};
    use tower::ServiceExt;

    /// Smoke test: the human-readable offchain resolver ABI parses and can be
    /// wrapped in a `BaseContract`.
    #[test]
    fn it_parse_offchain_resolver_abi() {
        let parsed = AbiParser::default().parse_str(r#"[
            function resolve(bytes memory name, bytes memory data) external view returns(bytes memory)
        ]"#);
        let contract = BaseContract::from(parsed.unwrap());
        println!("{:?}", contract.methods);
    }

    /// A GET request whose selector has no registered handler still answers
    /// with HTTP 200 and a JSON message naming the unknown selector.
    #[tokio::test]
    async fn test_gateway_get_on_unknown_selector() {
        let uri = "/gateway/0x8464135c8f25da09e49bc8782676a84730c318bc/0x9061b92300000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000000080000000000000000000000000000000000000000000000000000000000000000a0474657374036574680000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008459d1d43ceb4f647bea6caa36333c816d7b46fdcb05f9466ecacc140ea8c66faf15b3d9f100000000000000000000000000000000000000000000000000000000000000400000000000000000000000000000000000000000000000000000000000000005656d61696c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000.json";

        // No handlers are registered, so every selector is unknown.
        let app = Server::new(IpAddr::V4("127.0.0.1".parse().unwrap()), 8080).router();
        let request = Request::builder().uri(uri).body(Body::empty()).unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let raw = hyper::body::to_bytes(response.into_body()).await.unwrap();
        let body: Value = serde_json::from_slice(&raw).unwrap();
        assert_eq!(
            body,
            json!({ "message": "No implementation for function with selector 0x9061b923"})
        );
    }
}