use async_trait::async_trait;
use chrono::{DateTime, NaiveDate, Utc};
use sea_orm::sea_query::OnConflict;
use sea_orm::*;
use std::sync::Arc;
use crate::schemas::experiment::{self, ActiveModel, Entity, Model};
use crate::types::Experiment;
/// Filters and paging controls passed to `ExperimentRepository::find_many`.
///
/// As consumed by `SeaOrmExperimentRepository::find_many`, every `Option` field that is
/// `None` adds no condition to the query; fields that are `Some` are combined as
/// successive filters on the same query.
#[derive(Debug, Clone)]
pub struct FindOptions {
/// When set, keep only rows whose `ticker` column equals this value.
pub ticker: Option<String>,
/// When set, keep only rows whose `experiment_type` column equals this value.
pub experiment_type: Option<String>,
/// When set, keep only rows whose `reviewed` column equals this value.
pub reviewed: Option<bool>,
/// Maximum number of rows to return; `SeaOrmExperimentRepository::find_many` reduces
/// values above 100 to 100.
pub limit: u64,
/// When set, keep only rows whose `created_at` column equals this exact timestamp.
pub created_at: Option<DateTime<Utc>>,
/// When set, keep only rows whose `date` column equals this date.
pub date: Option<NaiveDate>,
/// When set, keep only rows whose `id` is strictly less than this value. Results are
/// ordered by `id` descending, so this can serve as a paging cursor.
pub before_id: Option<i64>, }
/// Asynchronous storage interface for experiment rows.
///
/// Implementors must be `Send + Sync`. Every method reports failures as a sea-orm `DbErr`.
#[async_trait]
pub trait ExperimentRepository: Send + Sync {
/// Inserts `rows` and returns the resulting `InsertResult`.
///
/// The implementation in this file resolves key conflicts by updating the existing row
/// (an upsert) rather than failing.
async fn insert_many(&self, rows: &Vec<Experiment>)
-> Result<InsertResult<ActiveModel>, DbErr>;
/// Returns the stored rows selected by `options`.
async fn find_many(&self, options: FindOptions) -> Result<Vec<Model>, DbErr>;
/// Updates the stored rows corresponding to `rows` and returns an `UpdateResult`
/// carrying the number of rows affected.
async fn update_many(&self, rows: &Vec<Model>) -> Result<UpdateResult, DbErr>;
}
/// `ExperimentRepository` implementation that talks to the database through sea-orm.
pub struct SeaOrmExperimentRepository {
// Reference-counted connection handle; every query in this type runs against it.
db_conn: Arc<DatabaseConnection>,
}
impl SeaOrmExperimentRepository {
pub fn new(db_conn: Arc<DatabaseConnection>) -> Self {
Self { db_conn }
}
}
#[async_trait]
impl ExperimentRepository for SeaOrmExperimentRepository {
    /// Upserts `rows` with a single `INSERT ... ON CONFLICT ... DO UPDATE` statement.
    ///
    /// The conflict target is `(ticker, date, experiment_type)`. When a row with the same
    /// key already exists, its `json_string`, `reviewed` and `created_at` columns are
    /// overwritten with the incoming values. `date` is always derived from `created_at`,
    /// and `reviewed` is always written as `false` — so re-inserting an existing key
    /// resets its reviewed flag.
    ///
    /// NOTE(review): the outcome for an empty `rows` is decided by sea-orm, not by this
    /// method — confirm against the sea-orm version in use before passing an empty vector.
    ///
    /// # Errors
    ///
    /// Returns the [`DbErr`] raised while executing the statement.
    async fn insert_many(
        &self,
        rows: &Vec<Experiment>,
    ) -> Result<InsertResult<ActiveModel>, DbErr> {
        let on_conflict = OnConflict::columns([
            experiment::Column::Ticker,
            experiment::Column::Date,
            experiment::Column::ExperimentType,
        ])
        .update_columns([
            experiment::Column::JsonString,
            experiment::Column::Reviewed,
            experiment::Column::CreatedAt,
        ])
        .to_owned();

        // The primary key is left `NotSet` (via `Default`) so the database assigns it.
        let active_models = rows.iter().map(|row| ActiveModel {
            ticker: Set(row.ticker.clone()),
            experiment_type: Set(row.experiment_type.clone()),
            reviewed: Set(false),
            json_string: Set(row.json_string.clone()),
            created_at: Set(row.created_at),
            date: Set(row.created_at.date_naive()),
            ..Default::default()
        });

        Entity::insert_many(active_models)
            .on_conflict(on_conflict)
            .exec(&*self.db_conn)
            .await
    }

    /// Returns the rows matching every filter that is set in `options`, ordered by `id`
    /// descending.
    ///
    /// Filters left as `None` are not applied. At most `options.limit` rows are returned,
    /// and never more than 100 regardless of the requested limit.
    ///
    /// # Errors
    ///
    /// Returns the [`DbErr`] raised while executing the query.
    async fn find_many(&self, options: FindOptions) -> Result<Vec<Model>, DbErr> {
        /// Upper bound on the page size a caller can request.
        const MAX_LIMIT: u64 = 100;

        let mut query = Entity::find().order_by_desc(experiment::Column::Id);
        if let Some(ticker) = options.ticker {
            query = query.filter(experiment::Column::Ticker.eq(ticker));
        }
        if let Some(experiment_type) = options.experiment_type {
            query = query.filter(experiment::Column::ExperimentType.eq(experiment_type));
        }
        if let Some(reviewed) = options.reviewed {
            query = query.filter(experiment::Column::Reviewed.eq(reviewed));
        }
        if let Some(created_at) = options.created_at {
            query = query.filter(experiment::Column::CreatedAt.eq(created_at));
        }
        if let Some(date) = options.date {
            query = query.filter(experiment::Column::Date.eq(date));
        }
        if let Some(before_id) = options.before_id {
            // Results are id-descending, so `id < before_id` selects the next page.
            query = query.filter(experiment::Column::Id.lt(before_id));
        }

        query
            .limit(options.limit.min(MAX_LIMIT))
            .all(&*self.db_conn)
            .await
    }

    /// Writes every column of each row in `rows` back to the database, matched by `id`,
    /// inside one transaction.
    ///
    /// `date` is recomputed from `created_at` (the same rule `insert_many` uses); it is
    /// not copied from the incoming row. An empty `rows` returns `rows_affected: 0`
    /// without opening a transaction.
    ///
    /// # Errors
    ///
    /// Returns the first [`DbErr`] raised while beginning the transaction, updating a row,
    /// or committing. On an update error the function returns before `commit`, so the
    /// transaction is dropped uncommitted.
    async fn update_many(&self, rows: &Vec<Model>) -> Result<UpdateResult, DbErr> {
        if rows.is_empty() {
            // Nothing to write: skip the BEGIN/COMMIT round-trip entirely.
            return Ok(UpdateResult { rows_affected: 0 });
        }

        let txn = self.db_conn.begin().await?;
        let mut rows_affected = 0u64;
        for row in rows {
            let active_model = ActiveModel {
                id: Set(row.id),
                ticker: Set(row.ticker.clone()),
                experiment_type: Set(row.experiment_type.clone()),
                reviewed: Set(row.reviewed),
                json_string: Set(row.json_string.clone()),
                created_at: Set(row.created_at),
                date: Set(row.created_at.date_naive()),
            };
            active_model.update(&txn).await?;
            // Reaching this line means the update call succeeded for this row.
            rows_affected += 1;
        }
        txn.commit().await?;

        Ok(UpdateResult { rows_affected })
    }
}