use super::Edgar;
use super::error::{EdgarError, Result};
use super::traits::SearchOperations;
use async_trait::async_trait;
use serde::{Deserialize, Deserializer, de};
/// Top-level JSON body of a search response.
///
/// This is the type that `SearchOperations::search` deserializes the response
/// text into. Field names (including the underscore-prefixed one) are the JSON
/// keys themselves, since no serde rename is applied.
#[derive(Debug, Clone, Deserialize)]
pub struct SearchResponse {
/// NOTE(review): presumably the server-side query time — unit not established here, confirm.
pub took: u32,
/// NOTE(review): presumably set when the server gave up before finishing — confirm.
pub timed_out: bool,
/// Shard counters, read from the `_shards` JSON key.
pub _shards: Shards,
/// Hit totals plus the hits contained in this response.
pub hits: Hits,
}
/// Shard counters carried in the `_shards` field of [`SearchResponse`].
///
/// NOTE(review): the field names look like search-backend shard statistics;
/// their exact meaning is defined by the remote API, not by this file.
#[derive(Debug, Clone, Deserialize)]
pub struct Shards {
/// Value of the `total` JSON key.
pub total: u32,
/// Value of the `successful` JSON key.
pub successful: u32,
/// Value of the `skipped` JSON key.
pub skipped: u32,
/// Value of the `failed` JSON key.
pub failed: u32,
}
/// The `hits` section of a [`SearchResponse`]: totals plus the returned hits.
#[derive(Debug, Clone, Deserialize)]
pub struct Hits {
/// Total-hit information; `search_all` reads `total.value` to plan pagination.
pub total: TotalHits,
/// Optional score; falls back to `None` when the key is missing.
#[serde(default)]
pub max_score: Option<f64>,
/// The hits contained in this response.
pub hits: Vec<Hit>,
}
/// Total-hit information reported alongside a page of results.
#[derive(Debug, Clone, Deserialize)]
pub struct TotalHits {
/// Hit count; `search_all` uses it as the number of results to page through.
pub value: u32,
/// NOTE(review): presumably qualifies `value` as exact or a bound (e.g. "eq"/"gte") — confirm against the API.
pub relation: String,
}
/// A single search hit.
///
/// Field names keep their leading underscore so they match the JSON keys
/// directly (no serde rename is applied).
#[derive(Debug, Clone, Deserialize)]
pub struct Hit {
/// Value of the `_index` JSON key.
pub _index: String,
/// Value of the `_id` JSON key.
pub _id: String,
/// Optional score; falls back to `None` when the key is missing.
#[serde(default)]
pub _score: Option<f64>,
/// The hit's payload, read from the `_source` JSON key.
pub _source: Source,
}
/// Payload of a [`Hit`], read from its `_source` JSON key.
///
/// Every field name is the JSON key it is read from. The meaning of the
/// individual keys is defined by the remote API and is not established in this
/// file; notes below are limited to what the declarations themselves show.
#[derive(Debug, Clone, Deserialize)]
pub struct Source {
/// List of strings under the `ciks` key.
pub ciks: Vec<String>,
/// Optional; falls back to `None` when the key is missing.
#[serde(default)]
pub period_ending: Option<String>,
/// Optional list of strings.
pub file_num: Option<Vec<String>>,
/// List of strings under the `display_names` key.
pub display_names: Vec<String>,
/// Optional; falls back to `None` when the key is missing.
#[serde(default)]
pub xsl: Option<String>,
/// Parsed by `deserialize_sequence`, so the JSON value may be either an
/// integer or a string containing an integer.
#[serde(deserialize_with = "deserialize_sequence")]
pub sequence: u32,
/// List of strings under the `root_forms` key.
pub root_forms: Vec<String>,
/// Kept as the raw string from the response; no date parsing is done here.
pub file_date: String,
/// List of strings under the `biz_states` key.
pub biz_states: Vec<String>,
/// List of strings under the `sics` key.
pub sics: Vec<String>,
/// Value of the `form` key.
pub form: String,
/// Value of the `adsh` key.
/// NOTE(review): presumably the filing accession number — confirm.
pub adsh: String,
/// List of strings under the `film_num` key.
pub film_num: Vec<String>,
/// List of strings under the `biz_locations` key.
pub biz_locations: Vec<String>,
/// Value of the `file_type` key.
pub file_type: String,
/// Optional; falls back to `None` when the key is missing.
#[serde(default)]
pub file_description: Option<String>,
/// List of strings under the `inc_states` key.
pub inc_states: Vec<String>,
/// Optional list of strings.
pub items: Option<Vec<String>>,
}
/// Optional filters and paging controls for a search request.
///
/// All fields default to `None`. `to_query_params` emits one query parameter
/// per field that is `Some` and skips the rest; the key used for each field is
/// noted below.
#[derive(Debug, Clone, Default)]
pub struct SearchOptions {
/// Sent as `keysTyped`.
pub keys_typed: Option<String>,
/// Sent as `q`.
pub query: Option<String>,
/// Sent as `category`.
pub category: Option<String>,
/// Sent as `locationCode`.
pub location_code: Option<String>,
/// Sent as `entityName`.
pub entity_name: Option<String>,
/// Sent as `forms`, with the entries joined by commas.
pub forms: Option<Vec<String>>,
/// Sent as `locationCodes`, with the entries joined by commas.
pub location_codes: Option<Vec<String>>,
/// Sent as `page`. Overwritten by `search_all`.
pub page: Option<u32>,
/// Sent as `from`. Set per page by `search_all`.
pub from: Option<u32>,
/// Sent as `count`. Overwritten by `search_all`.
pub count: Option<u32>,
/// Sent as `reverse_order` with the value `TRUE` or `FALSE`. Overwritten by `search_all`.
pub reverse_order: Option<bool>,
/// Sent as `startdt`; passed through verbatim, no format validation is done here.
pub start_date: Option<String>,
/// Sent as `enddt`; passed through verbatim, no format validation is done here.
pub end_date: Option<String>,
/// Sent as `stemming`.
pub stemming: Option<String>,
/// Sent as `ciks`, with the entries joined by commas.
pub ciks: Option<Vec<String>>,
/// Sent as `sic`.
pub sic: Option<String>,
/// Sent as `incorporated_location` with the value `true` or `false`.
pub incorporated_location: Option<bool>,
}
/// Deserializes a `u32` from either an unsigned integer or a string that
/// contains an integer.
///
/// Used through `#[serde(deserialize_with = "deserialize_sequence")]` on
/// `Source::sequence`, whose JSON value may arrive in either form.
///
/// # Errors
///
/// Returns a deserialization error when:
/// * the integer does not fit in a `u32` (previously this was silently
///   truncated by an `as` cast),
/// * the string cannot be parsed as a `u32`,
/// * the value has any other type.
fn deserialize_sequence<'de, D>(deserializer: D) -> std::result::Result<u32, D::Error>
where
    D: Deserializer<'de>,
{
    struct SequenceVisitor;

    impl<'de> de::Visitor<'de> for SequenceVisitor {
        type Value = u32;

        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("an integer or a string containing an integer")
        }

        fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
        where
            E: de::Error,
        {
            // Checked narrowing: reject values above `u32::MAX` instead of wrapping.
            // The trait path is spelled out so this compiles regardless of whether
            // `TryFrom` is in the prelude for the crate's edition.
            <u32 as std::convert::TryFrom<u64>>::try_from(value)
                .map_err(|_| E::invalid_value(de::Unexpected::Unsigned(value), &self))
        }

        fn visit_str<E>(self, value: &str) -> std::result::Result<Self::Value, E>
        where
            E: de::Error,
        {
            value.parse().map_err(de::Error::custom)
        }
    }

    // The input type is not known up front, so let the format dispatch to the
    // matching `visit_*` method.
    deserializer.deserialize_any(SequenceVisitor)
}
impl SearchOptions {
pub fn new() -> Self {
Self::default()
}
pub fn with_query(mut self, query: impl Into<String>) -> Self {
self.query = Some(query.into());
self
}
pub fn with_keys_typed(mut self, keys: impl Into<String>) -> Self {
self.keys_typed = Some(keys.into());
self
}
pub fn with_category(mut self, category: impl Into<String>) -> Self {
self.category = Some(category.into());
self
}
pub fn with_location_code(mut self, code: impl Into<String>) -> Self {
self.location_code = Some(code.into());
self
}
pub fn with_entity_name(mut self, name: impl Into<String>) -> Self {
self.entity_name = Some(name.into());
self
}
pub fn with_forms(mut self, forms: Vec<String>) -> Self {
self.forms = Some(forms);
self
}
pub fn with_location_codes(mut self, codes: Vec<String>) -> Self {
self.location_codes = Some(codes);
self
}
pub fn with_page(mut self, page: u32) -> Self {
self.page = Some(page);
self
}
pub fn with_from(mut self, from: u32) -> Self {
self.from = Some(from);
self
}
pub fn with_count(mut self, count: u32) -> Self {
self.count = Some(count);
self
}
pub fn with_reverse_order(mut self, reverse: bool) -> Self {
self.reverse_order = Some(reverse);
self
}
pub fn with_date_range(mut self, start_date: String, end_date: String) -> Self {
self.start_date = Some(start_date);
self.end_date = Some(end_date);
self
}
pub fn with_stemming(mut self, stemming: impl Into<String>) -> Self {
self.stemming = Some(stemming.into());
self
}
pub fn with_ciks<T>(mut self, ciks: T) -> Self
where
T: Into<Vec<String>>,
{
self.ciks = Some(ciks.into());
self
}
pub fn with_cik(self, cik: impl Into<String>) -> Self {
self.with_ciks(vec![cik.into()])
}
pub fn with_sic(mut self, sic: impl Into<String>) -> Self {
self.sic = Some(sic.into());
self
}
pub fn with_incorporated_location(mut self, incorporated: bool) -> Self {
self.incorporated_location = Some(incorporated);
self
}
pub fn to_query_params(&self) -> Vec<(String, String)> {
let mut params = Vec::new();
if let Some(ref query) = self.query {
params.push(("q".to_string(), query.clone()));
}
if let Some(ref keys) = self.keys_typed {
params.push(("keysTyped".to_string(), keys.clone()));
}
if let Some(ref category) = self.category {
params.push(("category".to_string(), category.clone()));
}
if let Some(ref code) = self.location_code {
params.push(("locationCode".to_string(), code.clone()));
}
if let Some(ref name) = self.entity_name {
params.push(("entityName".to_string(), name.clone()));
}
if let Some(ref forms) = self.forms {
params.push(("forms".to_string(), forms.join(",")));
}
if let Some(ref codes) = self.location_codes {
params.push(("locationCodes".to_string(), codes.join(",")));
}
if let Some(page) = self.page {
params.push(("page".to_string(), page.to_string()));
}
if let Some(from) = self.from {
params.push(("from".to_string(), from.to_string()));
}
if let Some(count) = self.count {
params.push(("count".to_string(), count.to_string()));
}
if let Some(reverse) = self.reverse_order {
params.push((
"reverse_order".to_string(),
if reverse { "TRUE" } else { "FALSE" }.to_string(),
));
}
if let Some(ref start) = self.start_date {
params.push(("startdt".to_string(), start.clone()));
}
if let Some(ref end) = self.end_date {
params.push(("enddt".to_string(), end.clone()));
}
if let Some(ref stemming) = self.stemming {
params.push(("stemming".to_string(), stemming.clone()));
}
if let Some(ref ciks) = self.ciks {
params.push(("ciks".to_string(), ciks.join(",")));
}
if let Some(ref sic) = self.sic {
params.push(("sic".to_string(), sic.clone()));
}
if let Some(incorporated) = self.incorporated_location {
params.push((
"incorporated_location".to_string(),
incorporated.to_string(),
));
}
params
}
}
#[async_trait]
impl SearchOperations for Edgar {
    /// Performs one search request and parses the JSON reply.
    ///
    /// The options are turned into query pairs with
    /// `SearchOptions::to_query_params`, URL-encoded, appended to
    /// `self.search_url()` after a `?`, and fetched with `self.get`.
    ///
    /// # Errors
    ///
    /// * `EdgarError::InvalidResponse` if the parameters cannot be URL-encoded.
    /// * Whatever `self.get` returns on failure.
    /// * The converted `serde_json` error if the body is not a valid
    ///   `SearchResponse`.
    async fn search(&self, options: SearchOptions) -> Result<SearchResponse> {
        let params = options.to_query_params();
        // Fix: this argument had been mangled into `¶ms` (an HTML-entity
        // artefact of `&params`), which does not compile.
        let query_string = serde_urlencoded::to_string(&params)
            .map_err(|e| EdgarError::InvalidResponse(e.to_string()))?;
        let url = format!("{}?{}", self.search_url(), query_string);
        let response = self.get(&url).await?;
        Ok(serde_json::from_str(&response)?)
    }

    /// Collects the hits of every result page for `options`.
    ///
    /// The caller's `count`, `page` and `reverse_order` are overwritten
    /// (100 per page, starting at page 1, not reversed). The first page is
    /// fetched on its own to learn the total hit count; the remaining pages
    /// are then requested in batches of up to `BATCH_SIZE` concurrent calls
    /// to [`search`](Self::search).
    ///
    /// # Errors
    ///
    /// Returns the error of the first request, or — once every request of a
    /// batch has finished — the first error found in that batch. Hits gathered
    /// so far are discarded in that case.
    async fn search_all(&self, mut options: SearchOptions) -> Result<Vec<Hit>> {
        // Maximum number of page requests awaited together.
        const BATCH_SIZE: u32 = 7;
        // Number of hits requested per page.
        const PAGE_SIZE: u32 = 100;

        options.count = Some(PAGE_SIZE);
        options.page = Some(1);
        options.reverse_order = Some(false);

        let initial_response = self.search(options.clone()).await?;
        let total_hits = initial_response.hits.total.value;
        tracing::info!("Found {} total hits", total_hits);

        let mut all_hits = Vec::with_capacity(total_hits as usize);
        all_hits.extend(initial_response.hits.hits);

        // Ceiling division: a partially filled last page still counts.
        let total_pages = (total_hits + PAGE_SIZE - 1) / PAGE_SIZE;

        // `current_page` is the last page already fetched; each round requests
        // pages `current_page + 1 ..= end_page`.
        let mut current_page = 1;
        while current_page < total_pages {
            let end_page = (current_page + BATCH_SIZE).min(total_pages);
            let mut batch_futures = Vec::with_capacity((end_page - current_page) as usize);

            for page in (current_page + 1)..=end_page {
                let skip = (page - 1) * PAGE_SIZE;
                if skip >= total_hits {
                    break;
                }

                let mut page_options = options.clone();
                page_options.page = Some(page);
                page_options.from = Some(skip);
                // The last page may hold fewer than PAGE_SIZE hits.
                page_options.count = Some(PAGE_SIZE.min(total_hits - skip));
                page_options.reverse_order = Some(false);
                batch_futures.push(self.search(page_options));
            }

            if batch_futures.is_empty() {
                break;
            }

            let results = futures_util::future::join_all(batch_futures).await;
            for result in results {
                match result {
                    Ok(response) => {
                        all_hits.extend(response.hits.hits);
                    }
                    Err(e) => {
                        tracing::error!("Error fetching page: {}", e);
                        return Err(e);
                    }
                }
            }

            current_page += BATCH_SIZE;
        }

        Ok(all_hits)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder stores what it is given and `to_query_params` renders it
    /// under the expected keys (comma-joined list, upper-case boolean).
    #[test]
    fn test_search_options_builder() {
        let forms = vec!["10-K".to_string(), "10-Q".to_string()];
        let params = SearchOptions::new()
            .with_query("test")
            .with_forms(forms)
            .with_count(10)
            .with_reverse_order(true)
            .to_query_params();

        let has_pair =
            |key: &str, value: &str| params.contains(&(key.to_string(), value.to_string()));

        assert!(has_pair("q", "test"));
        assert!(has_pair("forms", "10-K,10-Q"));
        assert!(has_pair("count", "10"));
        assert!(has_pair("reverse_order", "TRUE"));
    }
}