use std::collections::HashMap;
use std::error::Error;
use std::hash::Hash;
use std::sync::Mutex;
use crate::model::{ListOrdersRequest, OptionChainSnapshot, Order, TradingAccount};
#[allow(async_fn_in_trait)]
pub trait DataApi {
async fn get_account(&self) -> Result<TradingAccount, Box<dyn Error>>;
async fn get_order(&self, order_id: &str) -> Result<Order, Box<dyn Error>>;
async fn get_order_by_client_id(&self, client_order_id: &str) -> Result<Order, Box<dyn Error>>;
async fn list_orders(&self, request: &ListOrdersRequest) -> Result<Vec<Order>, Box<dyn Error>>;
async fn get_option_chain(
&self,
underlying_symbol: &str,
) -> Result<Vec<OptionChainSnapshot>, Box<dyn Error>>;
}
/// Query value identifying the account lookup.
///
/// It has no fields (presumably because `DataApi::get_account` takes no arguments),
/// so it is trivially `Copy` and has an obvious `Default`; both are derived so the
/// type can be passed by value and built generically.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AccountQuery;
/// Query identifying one order by its `order_id`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct OrderByIdQuery {
    /// The order's identifier (presumably the value accepted by `DataApi::get_order`).
    pub order_id: String,
}
/// Query identifying one order by its client-assigned `client_order_id`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct OrderByClientIdQuery {
    /// The client-side order identifier (presumably the value accepted by
    /// `DataApi::get_order_by_client_id`).
    pub client_order_id: String,
}
/// Query identifying the option chain of one underlying instrument.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct OptionChainQuery {
    /// Symbol of the underlying (presumably the value accepted by
    /// `DataApi::get_option_chain`).
    pub underlying_symbol: String,
}
/// An asynchronous origin of values of type `T`, looked up by a query of type `Q`.
///
/// `async fn` in a public trait means callers cannot require the returned future
/// to be `Send`; the `async_fn_in_trait` lint about this is silenced.
#[allow(async_fn_in_trait)]
pub trait Source<Q, T> {
    /// Resolves `query` into a value.
    ///
    /// # Errors
    /// Returns whatever error the implementation reports.
    async fn fetch(&self, query: &Q) -> Result<T, Box<dyn Error>>;
}
/// Asynchronous key/value storage for values of type `T`, keyed by queries of type `Q`.
///
/// `async fn` in a public trait means callers cannot require the returned futures
/// to be `Send`; the `async_fn_in_trait` lint about this is silenced.
#[allow(async_fn_in_trait)]
pub trait Store<Q, T> {
    /// Looks up the value stored for `query`; `Ok(None)` signals that there is no entry.
    async fn get(&self, query: &Q) -> Result<Option<T>, Box<dyn Error>>;

    /// Records `value` under `query`.
    async fn put(&self, query: &Q, value: &T) -> Result<(), Box<dyn Error>>;
}
/// Read-through cache pairing a [`Source`] of fresh values with a [`Store`] that
/// remembers them.
///
/// The struct itself carries no trait bounds; they are placed on the `impl`
/// blocks that need them, so any `S`/`C` pair can be held.
pub struct Repository<S, C> {
    /// Consulted only when the store has no entry for a query.
    source: S,
    /// Consulted first, and written after every successful fetch from `source`.
    store: C,
}
impl<S, C> Repository<S, C> {
pub fn new(source: S, store: C) -> Self {
Repository { source, store }
}
}
impl<S, C> Repository<S, C> {
    /// Returns the value for `query`, consulting the store before the source.
    ///
    /// On a store hit the cached value is returned and the source is not called.
    /// On a miss the value is fetched from the source, written to the store, and
    /// then returned.
    ///
    /// # Errors
    /// Propagates any error from the store lookup, the source fetch, or the store
    /// write. If the store write fails, the freshly fetched value is discarded and
    /// the write error is returned.
    pub async fn get<Q, T>(&self, query: &Q) -> Result<T, Box<dyn Error>>
    where
        S: Source<Q, T>,
        C: Store<Q, T>,
    {
        match self.store.get(query).await? {
            Some(hit) => Ok(hit),
            None => {
                let value = self.source.fetch(query).await?;
                self.store.put(query, &value).await?;
                Ok(value)
            }
        }
    }
}
impl<Q, T, S, C> Source<Q, T> for Repository<S, C>
where
S: Source<Q, T>,
C: Store<Q, T>,
{
async fn fetch(&self, query: &Q) -> Result<T, Box<dyn Error>> {
self.get(query).await
}
}
/// A [`Store`] backed by a `HashMap` held in memory.
///
/// The map sits behind a [`Mutex`] so entries can be written through `&self`.
/// The `Store` impl only ever inserts; entries are never removed.
pub struct InMemoryStore<Q, T> {
    /// Cached values keyed by query.
    values: Mutex<HashMap<Q, T>>,
}
impl<Q, T> InMemoryStore<Q, T> {
pub fn new() -> Self {
InMemoryStore {
values: Mutex::new(HashMap::new()),
}
}
}
impl<Q, T> Default for InMemoryStore<Q, T> {
fn default() -> Self {
Self::new()
}
}
/// Keys and values are cloned into and out of the map, so the lock is held only
/// for the duration of each call.
///
/// # Errors
/// Both methods return an `io::Error` ("in-memory store lock poisoned") if the
/// mutex is poisoned, i.e. a thread panicked while holding it.
impl<Q, T> Store<Q, T> for InMemoryStore<Q, T>
where
    Q: Eq + Hash + Clone,
    T: Clone,
{
    async fn get(&self, query: &Q) -> Result<Option<T>, Box<dyn Error>> {
        let Ok(map) = self.values.lock() else {
            return Err(std::io::Error::other("in-memory store lock poisoned").into());
        };
        Ok(map.get(query).cloned())
    }

    async fn put(&self, query: &Q, value: &T) -> Result<(), Box<dyn Error>> {
        let Ok(mut map) = self.values.lock() else {
            return Err(std::io::Error::other("in-memory store lock poisoned").into());
        };
        // Any previous value for this key is replaced and discarded.
        map.insert(query.clone(), value.clone());
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Source that returns a fixed value and counts how often it is called.
    struct CountingSource {
        calls: Arc<AtomicUsize>,
        value: String,
    }

    impl CountingSource {
        fn new(calls: Arc<AtomicUsize>, value: String) -> Self {
            CountingSource { calls, value }
        }
    }

    impl Source<String, String> for CountingSource {
        async fn fetch(&self, _query: &String) -> Result<String, Box<dyn Error>> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(self.value.clone())
        }
    }

    /// Source that always fails, used to check error propagation.
    struct FailingSource;

    impl Source<String, String> for FailingSource {
        async fn fetch(&self, _query: &String) -> Result<String, Box<dyn Error>> {
            Err(std::io::Error::other("source unavailable").into())
        }
    }

    #[tokio::test]
    async fn test_repository_reads_from_source_then_cache() {
        let calls = Arc::new(AtomicUsize::new(0));
        let repository = Repository::new(
            CountingSource::new(calls.clone(), "fresh".to_string()),
            InMemoryStore::<String, String>::new(),
        );
        let query = "account".to_string();
        let first = repository.get(&query).await.unwrap();
        let second = repository.get(&query).await.unwrap();
        assert_eq!(first, "fresh");
        assert_eq!(second, "fresh");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_repository_caches_each_query_separately() {
        let calls = Arc::new(AtomicUsize::new(0));
        let repository = Repository::new(
            CountingSource::new(calls.clone(), "fresh".to_string()),
            InMemoryStore::<String, String>::new(),
        );
        for query in ["order:1", "order:2", "order:1", "order:2"] {
            let value: String = repository.get(&query.to_string()).await.unwrap();
            assert_eq!(value, "fresh");
        }
        // One source call per distinct query; repeats are served from the store.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_repository_propagates_source_error_without_caching() {
        let repository = Repository::new(FailingSource, InMemoryStore::<String, String>::new());
        let query = "account".to_string();
        let error = repository.get(&query).await.unwrap_err();
        assert_eq!(error.to_string(), "source unavailable");
        // A failed fetch must not leave anything behind in the store.
        assert_eq!(repository.store.get(&query).await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_repository_can_serve_as_source_for_another_repository() {
        let calls = Arc::new(AtomicUsize::new(0));
        let inner = Repository::new(
            CountingSource::new(calls.clone(), "fresh".to_string()),
            InMemoryStore::<String, String>::new(),
        );
        let outer = Repository::new(inner, InMemoryStore::<String, String>::new());
        let query = "option_chain:SPY".to_string();
        let first: String = outer.get(&query).await.unwrap();
        let second: String = outer.get(&query).await.unwrap();
        assert_eq!(first, "fresh");
        assert_eq!(second, "fresh");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        // The inner layer cached the value on the way through, too.
        assert_eq!(
            outer.source.store.get(&query).await.unwrap(),
            Some("fresh".to_string())
        );
    }

    #[tokio::test]
    async fn test_in_memory_store_returns_written_value() {
        let store = InMemoryStore::<String, String>::new();
        let query = "order:1".to_string();
        let value = "cached".to_string();
        store.put(&query, &value).await.unwrap();
        assert_eq!(store.get(&query).await.unwrap(), Some("cached".to_string()));
    }

    #[tokio::test]
    async fn test_in_memory_store_misses_and_overwrites() {
        let store = InMemoryStore::<String, String>::new();
        let query = "order:1".to_string();
        assert_eq!(store.get(&query).await.unwrap(), None);
        store.put(&query, &"old".to_string()).await.unwrap();
        store.put(&query, &"new".to_string()).await.unwrap();
        assert_eq!(store.get(&query).await.unwrap(), Some("new".to_string()));
        assert_eq!(store.get(&"order:2".to_string()).await.unwrap(), None);
    }
}