use crate::{
error::{OpenFIGIError, OtherErrorKind, Result},
impl_filter_builder,
model::{
enums::{
Currency, ExchCode, MarketSecDesc, MicCode, OptionType, SecurityType, SecurityType2,
StateCode,
},
request::common::RequestFilters,
},
};
use chrono::NaiveDate;
use serde::{Deserialize, Serialize};
/// Request body for a FIGI search.
///
/// Serialized with camelCase keys. The shared [`RequestFilters`] are flattened
/// into the same JSON object as `query` and `start`, not nested under a key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SearchRequest {
    /// Search terms. Always serialized.
    pub query: String,

    /// Optional `start` value, omitted from the serialized output when `None`.
    /// NOTE(review): presumably a pagination cursor from a previous response — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start: Option<String>,

    /// Filters narrowing the search; their keys appear at the top level of the JSON.
    #[serde(flatten)]
    pub filters: RequestFilters,
}
impl SearchRequest {
    /// Creates a request for `query` with no `start` value and default filters.
    ///
    /// Nothing is validated here; call [`SearchRequest::validate`] after editing
    /// the filters directly.
    #[must_use]
    pub fn new(query: impl Into<String>) -> Self {
        let query: String = query.into();
        Self {
            query,
            start: None,
            filters: RequestFilters::default(),
        }
    }

    /// Returns a fresh [`SearchRequestBuilder`].
    #[must_use]
    pub fn builder() -> SearchRequestBuilder {
        SearchRequestBuilder::new()
    }

    /// Checks the request's filters for consistency.
    ///
    /// Only the filters are inspected; `query` and `start` are not checked here.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the filters' own `validate` method.
    pub fn validate(&self) -> Result<()> {
        let filters = &self.filters;
        filters.validate()?;
        Ok(())
    }
}
/// Incrementally assembles a [`SearchRequest`].
///
/// `query` must be set before `build` is called; `start` and the filters are optional.
/// Derives `Debug` and `Clone` (like the `SearchRequest` it produces) so a partially
/// configured builder can be inspected or reused as a template.
#[derive(Debug, Clone, Default)]
pub struct SearchRequestBuilder {
    /// Required search terms; `build` fails while this is `None`.
    query: Option<String>,
    /// Optional `start` value, copied verbatim into the built request.
    start: Option<String>,
    /// Filters accumulated through the builder's filter methods or `filters_mut`.
    filters: RequestFilters,
}
impl SearchRequestBuilder {
    /// Creates a builder with no query, no `start` value and default filters.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the (required) search terms, replacing any previous value.
    #[must_use]
    pub fn query(mut self, query: impl Into<String>) -> Self {
        let text: String = query.into();
        self.query = Some(text);
        self
    }

    /// Sets the optional `start` value, replacing any previous value.
    #[must_use]
    pub fn start(mut self, start: impl Into<String>) -> Self {
        let value: String = start.into();
        self.start = Some(value);
        self
    }

    /// Gives direct mutable access to the filters being accumulated.
    pub fn filters_mut(&mut self) -> &mut RequestFilters {
        &mut self.filters
    }

    // Shared filter setter methods are expanded here by the crate macro.
    impl_filter_builder!();

    /// Consumes the builder and returns a validated [`SearchRequest`].
    ///
    /// # Errors
    ///
    /// Returns a `Validation` error if no query was set, or any error produced by
    /// [`SearchRequest::validate`].
    pub fn build(self) -> Result<SearchRequest> {
        let Self {
            query,
            start,
            filters,
        } = self;

        let query = query.ok_or_else(|| {
            OpenFIGIError::other_error(OtherErrorKind::Validation, "query is required")
        })?;

        let request = SearchRequest {
            query,
            start,
            filters,
        };
        // `validate` only borrows `request`, so it can be moved out afterwards.
        request.validate().map(|()| request)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::enums::{Currency, ExchCode, MicCode, SecurityType2};
    use chrono::NaiveDate;

    /// Validates `request`, fails the test if validation succeeds, and returns the
    /// error's `Display` text so callers can check the message.
    fn validation_message(request: &SearchRequest) -> String {
        request
            .validate()
            .expect_err("validation should have failed")
            .to_string()
    }

    #[test]
    fn test_search_request_new_minimal() {
        let request = SearchRequest::new("ibm");
        assert_eq!(request.query, "ibm");
        assert!(request.start.is_none());
        assert!(request.filters.exch_code.is_none());
        assert!(request.filters.mic_code.is_none());
    }

    #[test]
    fn test_search_request_builder_minimal() {
        let request = SearchRequest::builder().query("ibm").build().unwrap();
        assert_eq!(request.query, "ibm");
    }

    #[test]
    fn test_search_request_builder_requires_query() {
        assert!(SearchRequest::builder().build().is_err());
    }

    #[test]
    fn test_search_request_builder_with_start() {
        let request = SearchRequest::builder()
            .query("ibm")
            .start("abc")
            .build()
            .unwrap();
        assert_eq!(request.start.as_deref(), Some("abc"));
    }

    #[test]
    fn test_search_request_builder_with_currency() {
        let request = SearchRequest::builder()
            .query("ibm")
            .currency(Currency::USD)
            .build()
            .unwrap();
        assert_eq!(request.filters.currency, Some(Currency::USD));
    }

    #[test]
    fn test_search_request_validate_exch_and_mic_code_conflict() {
        let mut request = SearchRequest::new("ibm");
        request.filters.exch_code = Some(ExchCode::A0);
        request.filters.mic_code = Some(MicCode::XCME);
        let msg = validation_message(&request);
        assert!(msg.contains("Cannot set both exchCode and micCode"));
    }

    #[test]
    fn test_search_request_validate_strike_range() {
        let mut request = SearchRequest::new("ibm");
        request.filters.strike = Some([Some(10.0), Some(5.0)]);
        let msg = validation_message(&request);
        assert!(msg.contains("strike: start value cannot be greater than end value"));
    }

    #[test]
    fn test_search_request_validate_expiration_required_for_option() {
        let mut request = SearchRequest::new("ibm");
        request.filters.security_type2 = Some(SecurityType2::Option);
        request.filters.expiration = None;
        let msg = validation_message(&request);
        assert!(msg.contains("expiration is required for Option or Warrant security types"));
    }

    #[test]
    fn test_search_request_validate_maturity_required_for_pool() {
        let mut request = SearchRequest::new("ibm");
        request.filters.security_type2 = Some(SecurityType2::Pool);
        let msg = validation_message(&request);
        assert!(msg.contains("maturity is required for Pool security types"));
    }

    #[test]
    fn test_search_request_validate_date_range_too_long() {
        let mut request = SearchRequest::new("ibm");
        let start = NaiveDate::from_ymd_opt(2025, 1, 1).expect("valid calendar date");
        let end = NaiveDate::from_ymd_opt(2026, 2, 1).expect("valid calendar date");
        request.filters.expiration = Some([Some(start), Some(end)]);
        let msg = validation_message(&request);
        assert!(msg.contains("date range cannot exceed 1 year"));
    }

    #[test]
    fn test_serialize_uses_expected_keys() {
        // `start` is skipped when `None` and emitted verbatim when set.
        let bare = serde_json::to_string(&SearchRequest::new("ibm")).unwrap();
        assert!(bare.contains("\"query\":\"ibm\""));
        assert!(!bare.contains("\"start\""));

        let paged = SearchRequest::builder()
            .query("ibm")
            .start("abc")
            .build()
            .unwrap();
        let paged = serde_json::to_string(&paged).unwrap();
        assert!(paged.contains("\"start\":\"abc\""));
    }

    #[test]
    fn test_serialize_deserialize_search_request() {
        let request = SearchRequest::builder()
            .query("ibm")
            .currency(Currency::USD)
            .build()
            .unwrap();
        let serialized = serde_json::to_string(&request).unwrap();
        let deserialized: SearchRequest = serde_json::from_str(&serialized).unwrap();
        assert_eq!(request, deserialized);
    }
}