use std::fmt;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use alpaca_core::QueryWriter;
use crate::Error;
use super::{
OrderClass, OrderSide, OrderType, PositionIntent, QueryOrderStatus, SortDirection, StopLoss,
TakeProfit, TimeInForce,
};
/// Optional filters for an order-list request.
///
/// `into_query` validates the fields and turns the ones that are set into
/// query-string pairs; fields left as `None` are passed to the query writer
/// as `None`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ListRequest {
/// Written under the `status` query key.
pub status: Option<QueryOrderStatus>,
/// Written under the `limit` query key; `into_query` rejects values outside `1..=500`.
pub limit: Option<u32>,
/// Written under the `after` query key after trimming; a blank value is rejected.
pub after: Option<String>,
/// Written under the `until` query key after trimming; a blank value is rejected.
pub until: Option<String>,
/// Written under the `direction` query key.
pub direction: Option<SortDirection>,
/// Written under the `nested` query key.
pub nested: Option<bool>,
/// Written as a CSV value under the `symbols` query key. When present the list
/// must be non-empty and every entry must be non-blank; entries are trimmed.
pub symbols: Option<Vec<String>>,
/// Written under the `side` query key.
pub side: Option<OrderSide>,
/// Written under the `asset_class` query key after trimming; a blank value is rejected.
pub asset_class: Option<String>,
}
impl ListRequest {
    /// Validates the filters and converts them into query-string pairs.
    ///
    /// Validation runs in the order `limit`, `after`, `until`, `symbols`,
    /// `asset_class`; the first failure is returned. Text values are written in
    /// their trimmed form.
    ///
    /// NOTE(review): `OrderSide::Unspecified` displays as an empty string, so what
    /// ends up under `side` in that case depends on `QueryWriter::push_opt` — confirm.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidRequest` when `limit` is outside `1..=500`, when a
    /// text filter is blank, or when `symbols` is empty or contains a blank entry.
    pub(crate) fn into_query(self) -> Result<Vec<(String, String)>, Error> {
        let Self {
            status,
            limit,
            after,
            until,
            direction,
            nested,
            symbols,
            side,
            asset_class,
        } = self;

        // Run every fallible check first, in the same order the keys are written.
        let limit = validate_limit(limit, 1, 500)?;
        let after = validate_optional_text("after", after)?;
        let until = validate_optional_text("until", until)?;
        let symbols = validate_optional_symbols(symbols)?;
        let asset_class = validate_optional_text("asset_class", asset_class)?;

        let mut writer = QueryWriter::default();
        writer.push_opt("status", status);
        writer.push_opt("limit", limit);
        writer.push_opt("after", after);
        writer.push_opt("until", until);
        writer.push_opt("direction", direction);
        writer.push_opt("nested", nested);
        if let Some(list) = symbols {
            writer.push_csv("symbols", list);
        }
        writer.push_opt("side", side);
        writer.push_opt("asset_class", asset_class);
        Ok(writer.finish())
    }
}
/// JSON body for an order-creation request.
///
/// Every field is optional and is left out of the serialized JSON when it is
/// `None`. Decimal fields are written through
/// `alpaca_core::decimal::string_contract::serialize_option_decimal`
/// (NOTE(review): presumably as JSON strings — confirm against that helper).
#[derive(Clone, Debug, Default, PartialEq, Serialize)]
pub struct CreateRequest {
/// When present, `validate` rejects an empty or whitespace-only value.
#[serde(skip_serializing_if = "Option::is_none")]
pub symbol: Option<String>,
#[serde(
skip_serializing_if = "Option::is_none",
serialize_with = "alpaca_core::decimal::string_contract::serialize_option_decimal"
)]
pub qty: Option<Decimal>,
#[serde(
skip_serializing_if = "Option::is_none",
serialize_with = "alpaca_core::decimal::string_contract::serialize_option_decimal"
)]
pub notional: Option<Decimal>,
#[serde(skip_serializing_if = "Option::is_none")]
pub side: Option<OrderSide>,
/// Serialized under the JSON key `type`.
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
pub r#type: Option<OrderType>,
#[serde(skip_serializing_if = "Option::is_none")]
pub time_in_force: Option<TimeInForce>,
#[serde(
skip_serializing_if = "Option::is_none",
serialize_with = "alpaca_core::decimal::string_contract::serialize_option_decimal"
)]
pub limit_price: Option<Decimal>,
#[serde(
skip_serializing_if = "Option::is_none",
serialize_with = "alpaca_core::decimal::string_contract::serialize_option_decimal"
)]
pub stop_price: Option<Decimal>,
#[serde(
skip_serializing_if = "Option::is_none",
serialize_with = "alpaca_core::decimal::string_contract::serialize_option_decimal"
)]
pub trail_price: Option<Decimal>,
#[serde(
skip_serializing_if = "Option::is_none",
serialize_with = "alpaca_core::decimal::string_contract::serialize_option_decimal"
)]
pub trail_percent: Option<Decimal>,
#[serde(skip_serializing_if = "Option::is_none")]
pub extended_hours: Option<bool>,
/// When present, `validate` rejects an empty or whitespace-only value.
#[serde(skip_serializing_if = "Option::is_none")]
pub client_order_id: Option<String>,
/// When this is `OrderClass::Mleg`, `validate` additionally checks `legs`.
#[serde(skip_serializing_if = "Option::is_none")]
pub order_class: Option<OrderClass>,
#[serde(skip_serializing_if = "Option::is_none")]
pub take_profit: Option<TakeProfit>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_loss: Option<StopLoss>,
/// Each leg is validated individually. For `OrderClass::Mleg` there must be
/// 2 to 4 legs whose `ratio_qty` values have a greatest common divisor of 1.
#[serde(skip_serializing_if = "Option::is_none")]
pub legs: Option<Vec<OptionLegRequest>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub position_intent: Option<PositionIntent>,
}
impl CreateRequest {
pub(crate) fn into_json(self) -> Result<serde_json::Value, Error> {
self.validate()?;
serde_json::to_value(self).map_err(|error| Error::InvalidRequest(error.to_string()))
}
pub(crate) fn validate(&self) -> Result<(), Error> {
if let Some(symbol) = &self.symbol {
validate_required_text("symbol", symbol)?;
}
if let Some(client_order_id) = &self.client_order_id {
validate_required_text("client_order_id", client_order_id)?;
}
if let Some(legs) = &self.legs {
for leg in legs {
leg.validate()?;
}
}
if self.order_class == Some(OrderClass::Mleg) {
validate_mleg_legs(self.legs.as_deref())?;
}
Ok(())
}
}
/// JSON body for an order-replacement request.
///
/// Every field is optional and is left out of the serialized JSON when it is
/// `None`. Decimal fields are written through
/// `alpaca_core::decimal::string_contract::serialize_option_decimal`
/// (NOTE(review): presumably as JSON strings — confirm against that helper).
#[derive(Clone, Debug, Default, PartialEq, Serialize)]
pub struct ReplaceRequest {
#[serde(
skip_serializing_if = "Option::is_none",
serialize_with = "alpaca_core::decimal::string_contract::serialize_option_decimal"
)]
pub qty: Option<Decimal>,
#[serde(skip_serializing_if = "Option::is_none")]
pub time_in_force: Option<TimeInForce>,
#[serde(
skip_serializing_if = "Option::is_none",
serialize_with = "alpaca_core::decimal::string_contract::serialize_option_decimal"
)]
pub limit_price: Option<Decimal>,
#[serde(
skip_serializing_if = "Option::is_none",
serialize_with = "alpaca_core::decimal::string_contract::serialize_option_decimal"
)]
pub stop_price: Option<Decimal>,
#[serde(
skip_serializing_if = "Option::is_none",
serialize_with = "alpaca_core::decimal::string_contract::serialize_option_decimal"
)]
pub trail: Option<Decimal>,
/// When present, `validate` rejects an empty or whitespace-only value.
#[serde(skip_serializing_if = "Option::is_none")]
pub client_order_id: Option<String>,
}
impl ReplaceRequest {
pub(crate) fn into_json(self) -> Result<serde_json::Value, Error> {
self.validate()?;
serde_json::to_value(self).map_err(|error| Error::InvalidRequest(error.to_string()))
}
pub(crate) fn validate(&self) -> Result<(), Error> {
if let Some(client_order_id) = &self.client_order_id {
validate_required_text("client_order_id", client_order_id)?;
}
Ok(())
}
}
/// One leg of an order's `legs` list.
///
/// Note that the derived `Default` yields an empty `symbol` and a `ratio_qty`
/// of 0, both of which this type's `validate` rejects.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct OptionLegRequest {
/// Must not be empty or whitespace-only.
pub symbol: String,
/// Must be greater than 0. Deserialized by
/// `deserialize_u32_from_string_or_number` and serialized by
/// `string_contract::serialize_u32` from `alpaca_core::integer`.
#[serde(
deserialize_with = "alpaca_core::integer::deserialize_u32_from_string_or_number",
serialize_with = "alpaca_core::integer::string_contract::serialize_u32"
)]
pub ratio_qty: u32,
/// Omitted from the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub side: Option<OrderSide>,
/// Omitted from the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub position_intent: Option<PositionIntent>,
}
impl OptionLegRequest {
fn validate(&self) -> Result<(), Error> {
validate_required_text("symbol", &self.symbol)?;
if self.ratio_qty == 0 {
return Err(Error::InvalidRequest(
"ratio_qty must be greater than 0".to_owned(),
));
}
Ok(())
}
}
/// Validates an order id for use as a single URL path segment.
///
/// Returns the trimmed id, or `Error::InvalidRequest` when it is blank or
/// contains `/`.
pub(crate) fn validate_order_id(order_id: &str) -> Result<String, Error> {
    const FIELD: &str = "order_id";
    validate_required_path_segment(FIELD, order_id)
}
/// Validates a client order id for use as a single URL path segment.
///
/// Returns the trimmed id, or `Error::InvalidRequest` when it is blank or
/// contains `/`.
pub(crate) fn validate_client_order_id(client_order_id: &str) -> Result<String, Error> {
    const FIELD: &str = "client_order_id";
    validate_required_path_segment(FIELD, client_order_id)
}
/// Applies `validate_required_text` to a value that may be absent.
///
/// `None` passes through untouched; `Some` yields the trimmed text or the
/// validation error.
fn validate_optional_text(
    name: &'static str,
    value: Option<String>,
) -> Result<Option<String>, Error> {
    match value {
        Some(text) => validate_required_text(name, &text).map(Some),
        None => Ok(None),
    }
}
fn validate_optional_symbols(value: Option<Vec<String>>) -> Result<Option<Vec<String>>, Error> {
match value {
None => Ok(None),
Some(values) if values.is_empty() => Err(Error::InvalidRequest(
"symbols must contain at least one symbol".to_owned(),
)),
Some(values) => values
.into_iter()
.map(|value| validate_required_text("symbols", &value))
.collect::<Result<Vec<_>, Error>>()
.map(Some),
}
}
fn validate_required_text(name: &str, value: &str) -> Result<String, Error> {
let trimmed = value.trim();
if trimmed.is_empty() {
return Err(Error::InvalidRequest(format!(
"{name} must not be empty or whitespace-only"
)));
}
Ok(trimmed.to_owned())
}
fn validate_required_path_segment(name: &str, value: &str) -> Result<String, Error> {
let value = validate_required_text(name, value)?;
if value.contains('/') {
return Err(Error::InvalidRequest(format!(
"{name} must not contain `/`"
)));
}
Ok(value)
}
fn validate_limit(limit: Option<u32>, min: u32, max: u32) -> Result<Option<u32>, Error> {
match limit {
Some(limit) if !(min..=max).contains(&limit) => Err(Error::InvalidRequest(format!(
"limit must be between {min} and {max}"
))),
_ => Ok(limit),
}
}
fn validate_mleg_legs(legs: Option<&[OptionLegRequest]>) -> Result<(), Error> {
let legs = legs.ok_or_else(|| {
Error::InvalidRequest(
"legs must contain 2 to 4 option legs when order_class is mleg".to_owned(),
)
})?;
if !(2..=4).contains(&legs.len()) {
return Err(Error::InvalidRequest(
"legs must contain 2 to 4 option legs when order_class is mleg".to_owned(),
));
}
let gcd = legs.iter().fold(0, |current, leg| {
if current == 0 {
leg.ratio_qty
} else {
greatest_common_divisor(current, leg.ratio_qty)
}
});
if gcd != 1 {
return Err(Error::InvalidRequest(
"ratio_qty values across mleg legs must use the simplest whole-number ratio".to_owned(),
));
}
Ok(())
}
/// Returns the greatest common divisor of `lhs` and `rhs` (Euclid's algorithm).
///
/// When one argument is 0 the other is returned, so `gcd(0, 0)` is 0.
fn greatest_common_divisor(lhs: u32, rhs: u32) -> u32 {
    if rhs == 0 {
        lhs
    } else {
        greatest_common_divisor(rhs, lhs % rhs)
    }
}
impl fmt::Display for QueryOrderStatus {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str(match self {
Self::Open => "open",
Self::Closed => "closed",
Self::All => "all",
})
}
}
impl fmt::Display for SortDirection {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str(match self {
Self::Asc => "asc",
Self::Desc => "desc",
})
}
}
impl fmt::Display for OrderSide {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str(match self {
Self::Buy => "buy",
Self::Sell => "sell",
Self::Unspecified => "",
})
}
}