use crate::error::Result;
use sea_orm::{prelude::*, ActiveValue, DatabaseConnection, TransactionTrait};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
#[derive(Debug, Clone)]
pub struct IndexViolation {
pub kind: ViolationKind,
pub key: String,
pub first_index: u64,
pub current_index: u64,
pub context: Option<String>,
}
/// Classification of an [`IndexViolation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ViolationKind {
    /// The key was seen before and no conflicting value was detected.
    DuplicateKey,
    /// The key was seen before with a value different from the one the caller expected.
    ConflictingValue,
}
/// In-memory uniqueness index over string keys, with a staging area for uncommitted entries.
///
/// Committed entries live in `index`; newly staged entries live in `pending` until
/// [`ContentIndex::commit`] moves them into `index` or [`ContentIndex::rollback`] discards them.
#[derive(Debug)]
pub struct ContentIndex {
    /// Committed entries: key -> (recorded index, optional associated value).
    index: RwLock<HashMap<String, (u64, Option<String>)>>,
    /// Entries staged since the last commit/rollback, in staging order.
    pending: RwLock<Vec<PendingEntry>>,
    /// Name returned by [`ContentIndex::name`].
    name: String,
}
/// A key staged via `ContentIndex::stage` but not yet committed.
#[derive(Debug, Clone)]
struct PendingEntry {
// The key being recorded.
key: String,
// The index supplied by the caller when the key was staged.
index: u64,
// Optional value associated with the key (compared against `expected_value` in `check`).
value: Option<String>,
}
impl ContentIndex {
    /// Creates an empty index identified by `name`.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            index: RwLock::new(HashMap::new()),
            pending: RwLock::new(Vec::new()),
            name: name.into(),
        }
    }

    /// Returns the name this index was created with.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Checks whether `key` has already been recorded, either committed or staged.
    ///
    /// Returns `None` when the key is new. Otherwise the returned violation's `first_index` is
    /// the earliest known occurrence. Committed entries are consulted before staged ones, since
    /// they normally predate anything still pending. Among staged entries, the first one staged
    /// wins.
    ///
    /// If `expected_value` is `Some` and differs from the stored value, the violation is
    /// [`ViolationKind::ConflictingValue`] with the stored value in `context`. Otherwise it is
    /// [`ViolationKind::DuplicateKey`].
    pub async fn check(
        &self,
        key: &str,
        current_index: u64,
        expected_value: Option<&str>,
    ) -> Option<IndexViolation> {
        // Acquire in the same order as `commit` (pending, then index) so the two cannot deadlock.
        // Holding both also prevents a concurrent commit from moving the entry between lookups.
        let pending = self.pending.read().await;
        let index = self.index.read().await;
        let (first_index, stored_value) = match index.get(key) {
            Some((first_index, stored)) => (*first_index, stored.as_deref()),
            None => pending
                .iter()
                .find(|entry| entry.key == key)
                .map(|entry| (entry.index, entry.value.as_deref()))?,
        };
        Some(Self::violation(
            key,
            current_index,
            first_index,
            stored_value,
            expected_value,
        ))
    }

    /// Builds the violation for a key already recorded at `first_index` with `stored_value`.
    fn violation(
        key: &str,
        current_index: u64,
        first_index: u64,
        stored_value: Option<&str>,
        expected_value: Option<&str>,
    ) -> IndexViolation {
        let conflicting =
            matches!(expected_value, Some(expected) if stored_value != Some(expected));
        let (kind, context) = if conflicting {
            (ViolationKind::ConflictingValue, stored_value.map(str::to_owned))
        } else {
            (ViolationKind::DuplicateKey, None)
        };
        IndexViolation {
            kind,
            key: key.to_string(),
            first_index,
            current_index,
            context,
        }
    }

    /// Returns `true` if `key` is staged or committed.
    pub async fn contains(&self, key: &str) -> bool {
        let pending = self.pending.read().await;
        if pending.iter().any(|e| e.key == key) {
            return true;
        }
        drop(pending);
        let index = self.index.read().await;
        index.contains_key(key)
    }

    /// Stages `key` as pending. It stays invisible to [`ContentIndex::len`] until committed.
    ///
    /// No duplicate check is performed here; call [`ContentIndex::check`] first if needed.
    pub async fn stage(&self, key: String, index: u64, value: Option<String>) {
        let mut pending = self.pending.write().await;
        pending.push(PendingEntry { key, index, value });
    }

    /// Moves all staged entries into the committed index.
    ///
    /// A key that is already committed, or staged more than once, keeps its first recorded
    /// `(index, value)`. This mirrors `ContentIndexStore::save`, which uses
    /// `ON CONFLICT DO NOTHING`.
    pub async fn commit(&self) {
        let mut pending = self.pending.write().await;
        let mut index = self.index.write().await;
        for entry in pending.drain(..) {
            index
                .entry(entry.key)
                .or_insert((entry.index, entry.value));
        }
    }

    /// Discards all staged entries; the committed index is untouched.
    pub async fn rollback(&self) {
        let mut pending = self.pending.write().await;
        pending.clear();
    }

    /// Number of committed keys (staged entries are not counted).
    pub async fn len(&self) -> usize {
        let index = self.index.read().await;
        index.len()
    }

    /// Returns `true` if no keys are committed (staged entries are ignored).
    pub async fn is_empty(&self) -> bool {
        let index = self.index.read().await;
        index.is_empty()
    }

    /// Number of staged, uncommitted entries.
    pub async fn pending_count(&self) -> usize {
        let pending = self.pending.read().await;
        pending.len()
    }

    /// Replaces the committed entries wholesale with `data`; staged entries are left untouched.
    pub async fn load_from(&self, data: HashMap<String, (u64, Option<String>)>) {
        let mut index = self.index.write().await;
        *index = data;
    }

    /// Commits like [`ContentIndex::commit`] and returns every drained entry in staging order.
    ///
    /// The returned list includes entries whose key was already present; those were not
    /// applied to the in-memory index. It is intended for `ContentIndexStore::save`, which
    /// ignores such duplicates via `ON CONFLICT DO NOTHING`.
    pub async fn commit_and_drain(&self) -> Vec<(String, u64, Option<String>)> {
        let mut pending = self.pending.write().await;
        let mut index = self.index.write().await;
        pending
            .drain(..)
            .map(|e| {
                index
                    .entry(e.key.clone())
                    .or_insert_with(|| (e.index, e.value.clone()));
                (e.key, e.index, e.value)
            })
            .collect()
    }
}
/// Persists [`ContentIndex`] entries in the `content_index` table via SeaORM.
///
/// Cloning is cheap: only the shared connection `Arc` is cloned.
#[derive(Clone)]
pub struct ContentIndexStore {
    /// Shared database connection used for all reads and writes.
    conn: Arc<DatabaseConnection>,
}
impl ContentIndexStore {
    /// Creates a store backed by `conn`.
    pub fn new(conn: Arc<DatabaseConnection>) -> Self {
        Self { conn }
    }

    /// Loads all rows for (`index_name`, `origin`) as `key -> (first_index, value)`.
    ///
    /// The result has the shape expected by `ContentIndex::load_from`. If several rows share a
    /// key, the last row returned by the query wins.
    ///
    /// # Errors
    /// Propagates any database error from the query.
    pub async fn load(
        &self,
        index_name: &str,
        origin: &str,
    ) -> Result<HashMap<String, (u64, Option<String>)>> {
        let rows = content_index::Entity::find()
            .filter(content_index::Column::IndexName.eq(index_name))
            .filter(content_index::Column::Origin.eq(origin))
            .all(&*self.conn)
            .await?;
        // `save` stores `u64 as i64`; casting back with `as u64` restores the exact bit
        // pattern, so the round trip is lossless even above i64::MAX.
        Ok(rows
            .into_iter()
            .map(|row| (row.key, (row.first_index as u64, row.value)))
            .collect())
    }

    /// Inserts `entries` for (`index_name`, `origin`) in a single transaction.
    ///
    /// Rows whose (index_name, origin, key) already exist are skipped (`ON CONFLICT DO NOTHING`),
    /// so the first recorded `first_index` is kept. An empty slice is a no-op.
    ///
    /// # Errors
    /// Propagates any database error. The transaction is then dropped uncommitted.
    pub async fn save(
        &self,
        index_name: &str,
        origin: &str,
        entries: &[(String, u64, Option<String>)],
    ) -> Result<()> {
        if entries.is_empty() {
            return Ok(());
        }
        let txn = self.conn.begin().await?;
        for (key, first_index, value) in entries {
            let model = content_index::ActiveModel {
                id: ActiveValue::NotSet,
                index_name: ActiveValue::Set(index_name.to_string()),
                origin: ActiveValue::Set(origin.to_string()),
                key: ActiveValue::Set(key.clone()),
                // Bit-preserving cast; `load` reverses it with `as u64`.
                first_index: ActiveValue::Set(*first_index as i64),
                value: ActiveValue::Set(value.clone()),
            };
            // `exec` treats a DO NOTHING skip as an error (RecordNotInserted / not found), which
            // would abort the whole batch on the first already-known key.
            // `exec_without_returning` just reports the affected-row count.
            content_index::Entity::insert(model)
                .on_conflict(
                    sea_orm::sea_query::OnConflict::columns([
                        content_index::Column::IndexName,
                        content_index::Column::Origin,
                        content_index::Column::Key,
                    ])
                    .do_nothing()
                    .to_owned(),
                )
                .exec_without_returning(&txn)
                .await?;
        }
        txn.commit().await?;
        Ok(())
    }
}
// SeaORM entity backing `ContentIndexStore`. Each row records the index at which `key` was first
// seen for a given (index_name, origin).
// NOTE(review): `ContentIndexStore::save` uses ON CONFLICT over (index_name, origin, key), which
// requires a unique constraint/index on those columns. It is not declared here; presumably it is
// created by a migration. Confirm.
mod content_index {
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "content_index")]
pub struct Model {
// Surrogate primary key; left `NotSet` on insert by `save`.
#[sea_orm(primary_key)]
pub id: i64,
// Name of the logical index (`ContentIndex::name`-style identifier).
pub index_name: String,
// Origin the entries belong to; `load` filters on it together with `index_name`.
pub origin: String,
// The indexed key.
pub key: String,
// u64 index stored via a bit-preserving `as i64` cast in `save`.
pub first_index: i64,
// Optional value associated with the key.
pub value: Option<String>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Staging makes a key visible to `contains`/`check`; committing moves it into `len`.
    #[tokio::test]
    async fn test_content_index_basic() {
        let idx = ContentIndex::new("test");
        assert!(idx.is_empty().await);
        assert!(!idx.contains("key1").await);

        idx.stage("key1".to_string(), 0, None).await;
        assert_eq!(idx.pending_count().await, 1);
        assert!(idx.contains("key1").await);

        let found = idx.check("key1", 1, None).await;
        assert!(matches!(
            found,
            Some(IndexViolation {
                kind: ViolationKind::DuplicateKey,
                ..
            })
        ));

        idx.commit().await;
        assert_eq!(idx.pending_count().await, 0);
        assert_eq!(idx.len().await, 1);
        assert!(idx.contains("key1").await);
    }

    // Same value => DuplicateKey; different value => ConflictingValue carrying the stored value.
    #[tokio::test]
    async fn test_content_index_value_conflict() {
        let idx = ContentIndex::new("test");
        let file = "filename.tar.gz";
        idx.stage(file.to_string(), 0, Some("hash1".to_string()))
            .await;
        idx.commit().await;

        let same = idx.check(file, 1, Some("hash1")).await;
        assert!(matches!(
            same,
            Some(IndexViolation {
                kind: ViolationKind::DuplicateKey,
                ..
            })
        ));

        let conflict = idx
            .check(file, 2, Some("hash2"))
            .await
            .expect("expected a violation");
        assert_eq!(conflict.kind, ViolationKind::ConflictingValue);
        assert_eq!(conflict.context, Some("hash1".to_string()));
    }

    // Rollback discards staged entries without touching the committed index.
    #[tokio::test]
    async fn test_content_index_rollback() {
        let idx = ContentIndex::new("test");
        for (position, key) in ["key1", "key2"].iter().enumerate() {
            idx.stage(key.to_string(), position as u64, None).await;
        }
        assert_eq!(idx.pending_count().await, 2);

        idx.rollback().await;
        assert_eq!(idx.pending_count().await, 0);
        assert!(idx.is_empty().await);
    }
}