1use crate::error::Result;
27use crate::http_client::HttpClient;
28use async_trait::async_trait;
29use reqwest::header::HeaderMap;
30use serde_json::{Map, Value};
31use std::collections::BTreeMap;
32
/// HTTP verb used when dispatching a signed request.
///
/// The default is [`HttpMethod::Get`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum HttpMethod {
    /// HTTP `GET`.
    #[default]
    Get,
    /// HTTP `POST`.
    Post,
    /// HTTP `PUT`.
    Put,
    /// HTTP `DELETE`.
    Delete,
}

impl HttpMethod {
    /// Returns the upper-case wire name of the method, e.g. `"POST"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            HttpMethod::Get => "GET",
            HttpMethod::Post => "POST",
            HttpMethod::Put => "PUT",
            HttpMethod::Delete => "DELETE",
        }
    }
}
58
59#[derive(Debug, Clone)]
61pub struct SigningContext {
62 pub method: HttpMethod,
64 pub endpoint: String,
66 pub params: BTreeMap<String, String>,
68 pub body: Option<Value>,
70 pub timestamp: String,
72 pub signature: Option<String>,
74}
75
76impl SigningContext {
77 pub fn new(method: HttpMethod, endpoint: String) -> Self {
79 Self {
80 method,
81 endpoint,
82 params: BTreeMap::new(),
83 body: None,
84 timestamp: String::new(),
85 signature: None,
86 }
87 }
88}
89
90#[async_trait]
97pub trait SigningStrategy: Send + Sync {
98 async fn prepare_request(&self, ctx: &mut SigningContext) -> Result<()>;
106
107 fn add_auth_headers(&self, headers: &mut HeaderMap, ctx: &SigningContext);
112}
113
114pub trait HasHttpClient {
116 fn http_client(&self) -> &HttpClient;
118
119 fn base_url(&self) -> &'static str;
121}
122
123pub struct SignedRequestBuilder<'a, E, S>
130where
131 E: HasHttpClient,
132 S: SigningStrategy,
133{
134 exchange: &'a E,
135 strategy: S,
136 params: BTreeMap<String, String>,
137 body: Option<Value>,
138 endpoint: String,
139 method: HttpMethod,
140}
141
142impl<'a, E, S> SignedRequestBuilder<'a, E, S>
143where
144 E: HasHttpClient,
145 S: SigningStrategy,
146{
147 pub fn new(exchange: &'a E, strategy: S, endpoint: impl Into<String>) -> Self {
149 Self {
150 exchange,
151 strategy,
152 params: BTreeMap::new(),
153 body: None,
154 endpoint: endpoint.into(),
155 method: HttpMethod::default(),
156 }
157 }
158
159 pub fn method(mut self, method: HttpMethod) -> Self {
161 self.method = method;
162 self
163 }
164
165 pub fn param(mut self, key: impl Into<String>, value: &dyn ToString) -> Self {
167 self.params.insert(key.into(), value.to_string());
168 self
169 }
170
171 pub fn optional_param<T: ToString>(mut self, key: impl Into<String>, value: Option<T>) -> Self {
173 if let Some(v) = value {
174 self.params.insert(key.into(), v.to_string());
175 }
176 self
177 }
178
179 pub fn params(mut self, params: BTreeMap<String, String>) -> Self {
181 self.params.extend(params);
182 self
183 }
184
185 pub fn body(mut self, body: Value) -> Self {
187 self.body = Some(body);
188 self
189 }
190
191 pub fn merge_json_params(mut self, params: Option<Value>) -> Self {
196 if let Some(Value::Object(map)) = params {
197 for (key, value) in map {
198 let string_value = match value {
199 Value::String(s) => s,
200 Value::Number(n) => n.to_string(),
201 Value::Bool(b) => b.to_string(),
202 _ => continue,
203 };
204 self.params.insert(key, string_value);
205 }
206 }
207 self
208 }
209
210 pub async fn execute(self) -> Result<Value> {
218 let mut ctx = SigningContext {
220 method: self.method,
221 endpoint: self.endpoint.clone(),
222 params: self.params,
223 body: self.body,
224 timestamp: String::new(),
225 signature: None,
226 };
227
228 self.strategy.prepare_request(&mut ctx).await?;
230
231 let mut headers = HeaderMap::new();
233 self.strategy.add_auth_headers(&mut headers, &ctx);
234
235 let base_url = self.exchange.base_url();
237 let full_url = format!("{base_url}{}", self.endpoint);
238
239 let client = self.exchange.http_client();
241
242 match self.method {
243 HttpMethod::Get => {
244 let query_string = build_query_string(&ctx.params);
245 let url = if query_string.is_empty() {
246 full_url
247 } else {
248 format!("{full_url}?{query_string}")
249 };
250 client.get(&url, Some(headers)).await
251 }
252 HttpMethod::Post => {
253 let body = ctx.body.unwrap_or_else(|| {
254 serde_json::to_value(&ctx.params).unwrap_or(Value::Object(Map::default()))
255 });
256 client.post(&full_url, Some(headers), Some(body)).await
257 }
258 HttpMethod::Put => {
259 let body = ctx.body.unwrap_or_else(|| {
260 serde_json::to_value(&ctx.params).unwrap_or(Value::Object(Map::default()))
261 });
262 client.put(&full_url, Some(headers), Some(body)).await
263 }
264 HttpMethod::Delete => {
265 let body = ctx.body.unwrap_or_else(|| {
266 serde_json::to_value(&ctx.params).unwrap_or(Value::Object(Map::default()))
267 });
268 client.delete(&full_url, Some(headers), Some(body)).await
269 }
270 }
271 }
272}
273
274pub fn build_query_string(params: &BTreeMap<String, String>) -> String {
278 params
279 .iter()
280 .map(|(k, v)| format!("{k}={}", urlencoding::encode(v)))
281 .collect::<Vec<_>>()
282 .join("&")
283}
284
/// Builds a query string of `key=value` pairs joined by `&`, in ascending key order,
/// without any percent-encoding of keys or values.
///
/// An empty map yields an empty string.
pub fn build_query_string_raw(params: &BTreeMap<String, String>) -> String {
    let mut query = String::new();
    for (index, (key, value)) in params.iter().enumerate() {
        if index > 0 {
            query.push('&');
        }
        query.push_str(key);
        query.push('=');
        query.push_str(value);
    }
    query
}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_http_method_default() {
        assert_eq!(HttpMethod::default(), HttpMethod::Get);
    }

    #[test]
    fn test_http_method_as_str() {
        assert_eq!(HttpMethod::Get.as_str(), "GET");
        assert_eq!(HttpMethod::Post.as_str(), "POST");
        assert_eq!(HttpMethod::Put.as_str(), "PUT");
        assert_eq!(HttpMethod::Delete.as_str(), "DELETE");
    }

    #[test]
    fn test_signing_context_new() {
        let ctx = SigningContext::new(HttpMethod::Post, "/api/test".to_string());
        assert_eq!(ctx.method, HttpMethod::Post);
        assert_eq!(ctx.endpoint, "/api/test");
        assert!(ctx.params.is_empty());
        assert!(ctx.body.is_none());
        assert!(ctx.timestamp.is_empty());
        assert!(ctx.signature.is_none());
    }

    #[test]
    fn test_build_query_string() {
        let mut params = BTreeMap::new();
        params.insert("symbol".to_string(), "BTCUSDT".to_string());
        params.insert("side".to_string(), "BUY".to_string());
        params.insert("amount".to_string(), "1.5".to_string());

        let query = build_query_string(&params);
        // BTreeMap iteration yields keys in ascending order.
        assert_eq!(query, "amount=1.5&side=BUY&symbol=BTCUSDT");
    }

    #[test]
    fn test_build_query_string_empty() {
        let params = BTreeMap::new();
        let query = build_query_string(&params);
        assert!(query.is_empty());
    }

    #[test]
    fn test_build_query_string_with_special_chars() {
        let mut params = BTreeMap::new();
        params.insert("symbol".to_string(), "BTC/USDT".to_string());

        let query = build_query_string(&params);
        assert_eq!(query, "symbol=BTC%2FUSDT");
    }

    #[test]
    fn test_build_query_string_raw() {
        let mut params = BTreeMap::new();
        params.insert("symbol".to_string(), "BTC/USDT".to_string());

        let query = build_query_string_raw(&params);
        assert_eq!(query, "symbol=BTC/USDT");
    }

    #[test]
    fn test_build_query_string_raw_empty() {
        let params = BTreeMap::new();
        assert!(build_query_string_raw(&params).is_empty());
    }

    #[test]
    fn test_build_query_string_raw_multiple_sorted() {
        let mut params = BTreeMap::new();
        params.insert("symbol".to_string(), "BTC/USDT".to_string());
        params.insert("amount".to_string(), "2".to_string());

        assert_eq!(build_query_string_raw(&params), "amount=2&symbol=BTC/USDT");
    }
}