1use crate::broker::{
10 Account, BrokerClient, BrokerError, HealthStatus, OrderFilter, Position, PositionSide,
11};
12use crate::{OrderRequest, OrderResponse, OrderSide, OrderStatus, OrderType, OrderUpdate, Symbol, TimeInForce};
13use async_trait::async_trait;
14use chrono::{DateTime, Utc};
15use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
16use reqwest::{Client, Method, StatusCode};
17use rust_decimal::Decimal;
18use serde::{Deserialize, Serialize};
19use std::num::NonZeroU32;
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::sync::RwLock;
23use tracing::{debug, error, info, warn};
24
25pub struct AlpacaBroker {
27 client: Client,
28 base_url: String,
29 api_key: String,
30 secret_key: String,
31 rate_limiter: DefaultDirectRateLimiter,
32 paper_trading: bool,
33}
34
35impl AlpacaBroker {
36 pub fn new(api_key: String, secret_key: String, paper_trading: bool) -> Self {
38 let base_url = if paper_trading {
39 "https://paper-api.alpaca.markets".to_string()
40 } else {
41 "https://api.alpaca.markets".to_string()
42 };
43
44 let client = Client::builder()
45 .timeout(Duration::from_secs(30))
46 .build()
47 .expect("Failed to create HTTP client");
48
49 let quota = Quota::per_minute(NonZeroU32::new(200).unwrap());
51 let rate_limiter = RateLimiter::direct(quota);
52
53 Self {
54 client,
55 base_url,
56 api_key,
57 secret_key,
58 rate_limiter,
59 paper_trading,
60 }
61 }
62
63 async fn request<T: serde::de::DeserializeOwned>(
65 &self,
66 method: Method,
67 path: &str,
68 body: Option<impl Serialize>,
69 ) -> Result<T, BrokerError> {
70 self.rate_limiter.until_ready().await;
72
73 let url = format!("{}{}", self.base_url, path);
74
75 let mut req = self
76 .client
77 .request(method.clone(), &url)
78 .header("APCA-API-KEY-ID", &self.api_key)
79 .header("APCA-API-SECRET-KEY", &self.secret_key);
80
81 if let Some(body) = body {
82 req = req.json(&body);
83 }
84
85 debug!("Alpaca API request: {} {}", method, path);
86
87 let response = req.send().await?;
88 let status = response.status();
89
90 match status {
91 StatusCode::OK | StatusCode::CREATED => {
92 let result = response.json().await?;
93 Ok(result)
94 }
95 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
96 Err(BrokerError::Auth("Invalid API keys".to_string()))
97 }
98 StatusCode::TOO_MANY_REQUESTS => Err(BrokerError::RateLimit),
99 StatusCode::NOT_FOUND => {
100 let error_text = response.text().await.unwrap_or_default();
101 Err(BrokerError::OrderNotFound(error_text))
102 }
103 StatusCode::UNPROCESSABLE_ENTITY => {
104 let error_text = response.text().await.unwrap_or_default();
105 if error_text.contains("insufficient") {
106 Err(BrokerError::InsufficientFunds)
107 } else {
108 Err(BrokerError::InvalidOrder(error_text))
109 }
110 }
111 _ => {
112 let error_text = response.text().await.unwrap_or_default();
113 error!("Alpaca API error: {} - {}", status, error_text);
114 Err(BrokerError::Network(error_text))
115 }
116 }
117 }
118}
119
120#[async_trait]
121impl BrokerClient for AlpacaBroker {
122 async fn get_account(&self) -> Result<Account, BrokerError> {
123 #[derive(Deserialize)]
124 struct AlpacaAccount {
125 id: String,
126 cash: String,
127 portfolio_value: String,
128 buying_power: String,
129 equity: String,
130 last_equity: String,
131 multiplier: String,
132 currency: String,
133 shorting_enabled: bool,
134 long_market_value: String,
135 short_market_value: String,
136 initial_margin: String,
137 maintenance_margin: String,
138 daytrade_count: i32,
139 daytrading_buying_power: String,
140 }
141
142 let account: AlpacaAccount = self.request(Method::GET, "/v2/account", None::<()>).await?;
143
144 Ok(Account {
145 account_id: account.id,
146 cash: Decimal::from_str_exact(&account.cash).unwrap_or_default(),
147 portfolio_value: Decimal::from_str_exact(&account.portfolio_value)
148 .unwrap_or_default(),
149 buying_power: Decimal::from_str_exact(&account.buying_power).unwrap_or_default(),
150 equity: Decimal::from_str_exact(&account.equity).unwrap_or_default(),
151 last_equity: Decimal::from_str_exact(&account.last_equity).unwrap_or_default(),
152 multiplier: account.multiplier,
153 currency: account.currency,
154 shorting_enabled: account.shorting_enabled,
155 long_market_value: Decimal::from_str_exact(&account.long_market_value)
156 .unwrap_or_default(),
157 short_market_value: Decimal::from_str_exact(&account.short_market_value)
158 .unwrap_or_default(),
159 initial_margin: Decimal::from_str_exact(&account.initial_margin).unwrap_or_default(),
160 maintenance_margin: Decimal::from_str_exact(&account.maintenance_margin)
161 .unwrap_or_default(),
162 day_trading_buying_power: Decimal::from_str_exact(&account.daytrading_buying_power)
163 .unwrap_or_default(),
164 daytrade_count: account.daytrade_count,
165 })
166 }
167
168 async fn get_positions(&self) -> Result<Vec<Position>, BrokerError> {
169 #[derive(Deserialize)]
170 struct AlpacaPosition {
171 symbol: String,
172 qty: String,
173 side: String,
174 avg_entry_price: String,
175 market_value: String,
176 cost_basis: String,
177 unrealized_pl: String,
178 unrealized_plpc: String,
179 current_price: String,
180 lastday_price: String,
181 change_today: String,
182 }
183
184 let positions: Vec<AlpacaPosition> =
185 self.request(Method::GET, "/v2/positions", None::<()>)
186 .await?;
187
188 Ok(positions
189 .into_iter()
190 .map(|pos| Position {
191 symbol: Symbol::new(&pos.symbol).expect("Invalid symbol from Alpaca"),
192 qty: pos.qty.parse().unwrap_or(0),
193 side: match pos.side.as_str() {
194 "long" => PositionSide::Long,
195 "short" => PositionSide::Short,
196 _ => PositionSide::Long,
197 },
198 avg_entry_price: Decimal::from_str_exact(&pos.avg_entry_price)
199 .unwrap_or_default(),
200 market_value: Decimal::from_str_exact(&pos.market_value).unwrap_or_default(),
201 cost_basis: Decimal::from_str_exact(&pos.cost_basis).unwrap_or_default(),
202 unrealized_pl: Decimal::from_str_exact(&pos.unrealized_pl).unwrap_or_default(),
203 unrealized_plpc: Decimal::from_str_exact(&pos.unrealized_plpc).unwrap_or_default(),
204 current_price: Decimal::from_str_exact(&pos.current_price).unwrap_or_default(),
205 lastday_price: Decimal::from_str_exact(&pos.lastday_price).unwrap_or_default(),
206 change_today: Decimal::from_str_exact(&pos.change_today).unwrap_or_default(),
207 })
208 .collect())
209 }
210
211 async fn place_order(&self, order: OrderRequest) -> Result<OrderResponse, BrokerError> {
212 #[derive(Serialize)]
213 struct AlpacaOrderRequest {
214 symbol: String,
215 qty: String,
216 side: String,
217 #[serde(rename = "type")]
218 order_type: String,
219 time_in_force: String,
220 #[serde(skip_serializing_if = "Option::is_none")]
221 limit_price: Option<String>,
222 #[serde(skip_serializing_if = "Option::is_none")]
223 stop_price: Option<String>,
224 }
225
226 #[derive(Deserialize)]
227 struct AlpacaOrderResponse {
228 id: String,
229 client_order_id: String,
230 symbol: String,
231 qty: String,
232 side: String,
233 status: String,
234 filled_qty: String,
235 filled_avg_price: Option<String>,
236 submitted_at: String,
237 filled_at: Option<String>,
238 }
239
240 let alpaca_order = AlpacaOrderRequest {
241 symbol: order.symbol.as_str().to_string(),
242 qty: order.quantity.to_string(),
243 side: match order.side {
244 OrderSide::Buy => "buy".to_string(),
245 OrderSide::Sell => "sell".to_string(),
246 },
247 order_type: match order.order_type {
248 OrderType::Market => "market".to_string(),
249 OrderType::Limit => "limit".to_string(),
250 OrderType::StopLoss => "stop".to_string(),
251 OrderType::StopLimit => "stop_limit".to_string(),
252 },
253 time_in_force: match order.time_in_force {
254 TimeInForce::Day => "day".to_string(),
255 TimeInForce::GTC => "gtc".to_string(),
256 TimeInForce::IOC => "ioc".to_string(),
257 TimeInForce::FOK => "fok".to_string(),
258 },
259 limit_price: order.limit_price.map(|p| p.to_string()),
260 stop_price: order.stop_price.map(|p| p.to_string()),
261 };
262
263 let response: AlpacaOrderResponse = self
264 .request(Method::POST, "/v2/orders", Some(alpaca_order))
265 .await?;
266
267 info!("Order placed on Alpaca: {}", response.id);
268
269 Ok(OrderResponse {
270 order_id: response.id,
271 client_order_id: response.client_order_id,
272 status: parse_order_status(&response.status),
273 filled_qty: response.filled_qty.parse().unwrap_or(0),
274 filled_avg_price: response
275 .filled_avg_price
276 .and_then(|p| Decimal::from_str_exact(&p).ok()),
277 submitted_at: DateTime::parse_from_rfc3339(&response.submitted_at)
278 .unwrap()
279 .with_timezone(&Utc),
280 filled_at: response
281 .filled_at
282 .and_then(|t| DateTime::parse_from_rfc3339(&t).ok())
283 .map(|dt| dt.with_timezone(&Utc)),
284 })
285 }
286
287 async fn cancel_order(&self, order_id: &str) -> Result<(), BrokerError> {
288 let path = format!("/v2/orders/{}", order_id);
289 let _: serde_json::Value = self.request(Method::DELETE, &path, None::<()>).await?;
290
291 info!("Order cancelled on Alpaca: {}", order_id);
292 Ok(())
293 }
294
295 async fn get_order(&self, order_id: &str) -> Result<OrderResponse, BrokerError> {
296 #[derive(Deserialize)]
297 struct AlpacaOrderResponse {
298 id: String,
299 client_order_id: String,
300 symbol: String,
301 qty: String,
302 side: String,
303 status: String,
304 filled_qty: String,
305 filled_avg_price: Option<String>,
306 submitted_at: String,
307 filled_at: Option<String>,
308 }
309
310 let path = format!("/v2/orders/{}", order_id);
311 let response: AlpacaOrderResponse = self.request(Method::GET, &path, None::<()>).await?;
312
313 Ok(OrderResponse {
314 order_id: response.id,
315 client_order_id: response.client_order_id,
316 status: parse_order_status(&response.status),
317 filled_qty: response.filled_qty.parse().unwrap_or(0),
318 filled_avg_price: response
319 .filled_avg_price
320 .and_then(|p| Decimal::from_str_exact(&p).ok()),
321 submitted_at: DateTime::parse_from_rfc3339(&response.submitted_at)
322 .unwrap()
323 .with_timezone(&Utc),
324 filled_at: response
325 .filled_at
326 .and_then(|t| DateTime::parse_from_rfc3339(&t).ok())
327 .map(|dt| dt.with_timezone(&Utc)),
328 })
329 }
330
331 async fn list_orders(&self, filter: OrderFilter) -> Result<Vec<OrderResponse>, BrokerError> {
332 #[derive(Deserialize)]
333 struct AlpacaOrderResponse {
334 id: String,
335 client_order_id: String,
336 symbol: String,
337 qty: String,
338 side: String,
339 status: String,
340 filled_qty: String,
341 filled_avg_price: Option<String>,
342 submitted_at: String,
343 filled_at: Option<String>,
344 }
345
346 let mut path = "/v2/orders".to_string();
347 let mut params = Vec::new();
348
349 if let Some(status) = filter.status {
350 params.push(format!("status={:?}", status).to_lowercase());
351 }
352 if let Some(limit) = filter.limit {
353 params.push(format!("limit={}", limit));
354 }
355
356 if !params.is_empty() {
357 path.push('?');
358 path.push_str(¶ms.join("&"));
359 }
360
361 let orders: Vec<AlpacaOrderResponse> = self.request(Method::GET, &path, None::<()>).await?;
362
363 Ok(orders
364 .into_iter()
365 .map(|order| OrderResponse {
366 order_id: order.id,
367 client_order_id: order.client_order_id,
368 status: parse_order_status(&order.status),
369 filled_qty: order.filled_qty.parse().unwrap_or(0),
370 filled_avg_price: order
371 .filled_avg_price
372 .and_then(|p| Decimal::from_str_exact(&p).ok()),
373 submitted_at: DateTime::parse_from_rfc3339(&order.submitted_at)
374 .unwrap()
375 .with_timezone(&Utc),
376 filled_at: order
377 .filled_at
378 .and_then(|t| DateTime::parse_from_rfc3339(&t).ok())
379 .map(|dt| dt.with_timezone(&Utc)),
380 })
381 .collect())
382 }
383
384 async fn health_check(&self) -> Result<HealthStatus, BrokerError> {
385 let _: serde_json::Value = self.request(Method::GET, "/v2/clock", None::<()>).await?;
387
388 Ok(HealthStatus::Healthy)
389 }
390}
391
392fn parse_order_status(status: &str) -> OrderStatus {
394 match status {
395 "new" | "pending_new" => OrderStatus::Pending,
396 "accepted" => OrderStatus::Accepted,
397 "partially_filled" => OrderStatus::PartiallyFilled,
398 "filled" => OrderStatus::Filled,
399 "canceled" | "pending_cancel" => OrderStatus::Cancelled,
400 "rejected" => OrderStatus::Rejected,
401 "expired" => OrderStatus::Expired,
402 _ => {
403 warn!("Unknown order status: {}", status);
404 OrderStatus::Pending
405 }
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_order_status() {
        assert_eq!(parse_order_status("new"), OrderStatus::Pending);
        assert_eq!(parse_order_status("filled"), OrderStatus::Filled);
        assert_eq!(parse_order_status("canceled"), OrderStatus::Cancelled);
    }

    /// Pins the mapping for every Alpaca status string explicitly handled.
    #[test]
    fn test_parse_order_status_covers_every_mapped_value() {
        let cases = [
            ("new", OrderStatus::Pending),
            ("pending_new", OrderStatus::Pending),
            ("accepted", OrderStatus::Accepted),
            ("partially_filled", OrderStatus::PartiallyFilled),
            ("filled", OrderStatus::Filled),
            ("canceled", OrderStatus::Cancelled),
            ("pending_cancel", OrderStatus::Cancelled),
            ("rejected", OrderStatus::Rejected),
            ("expired", OrderStatus::Expired),
        ];

        for (raw, expected) in cases.iter() {
            assert_eq!(&parse_order_status(raw), expected, "status {:?}", raw);
        }
    }

    /// Unrecognised statuses (including the empty string) fall back to `Pending`.
    #[test]
    fn test_unknown_order_status_falls_back_to_pending() {
        assert_eq!(parse_order_status("done_for_day"), OrderStatus::Pending);
        assert_eq!(parse_order_status(""), OrderStatus::Pending);
    }

    #[tokio::test]
    async fn test_alpaca_broker_creation() {
        let broker = AlpacaBroker::new(
            "test_key".to_string(),
            "test_secret".to_string(),
            true,
        );

        assert!(broker.paper_trading);
        assert_eq!(broker.base_url, "https://paper-api.alpaca.markets");
    }

    #[tokio::test]
    async fn test_alpaca_broker_creation_live() {
        let broker = AlpacaBroker::new(
            "test_key".to_string(),
            "test_secret".to_string(),
            false,
        );

        assert!(!broker.paper_trading);
        assert_eq!(broker.base_url, "https://api.alpaca.markets");
    }
}