use std::sync::Arc;
use serde::{Deserialize, Serialize};
use sqlx::Row;
use strum_macros::{AsRefStr, Display, EnumString};
use crate::storage::StorageError;
use crate::storage::db::SqlitePool;
/// The kind of probe a collector performs.
///
/// Rendered lowercase in both serde (JSON) and strum (Display/AsRefStr,
/// used for the DB `type` column); strum parsing additionally accepts any
/// ASCII casing (e.g. "PING") via `ascii_case_insensitive`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, EnumString, Display, AsRefStr,
)]
#[serde(rename_all = "lowercase")]
#[strum(serialize_all = "lowercase", ascii_case_insensitive)]
pub enum CollectorType {
Tcp,
Ping,
Http,
}
/// Where a collector definition originated: the static configuration
/// file (`Config`) or the runtime API (`Api`).
///
/// Stored/serialized as lowercase text; strum parsing ignores ASCII case.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, EnumString, Display, AsRefStr,
)]
#[serde(rename_all = "lowercase")]
#[strum(serialize_all = "lowercase", ascii_case_insensitive)]
pub enum CollectorSource {
Config,
Api,
}
/// One collector definition as persisted in the `collectors` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollectorRecord {
// Database row id; `None` until the record has been inserted.
pub id: Option<i64>,
pub collector_type: CollectorType,
pub name: String,
// Whether this definition came from the config file or the API.
pub source: CollectorSource,
pub enabled: bool,
// Exposed as "group" in JSON; `group` is a reserved keyword in Rust.
#[serde(rename = "group")]
pub group_name: String,
// Collector-specific settings, stored as a JSON text blob in the DB.
pub config: serde_json::Value,
// Unix timestamps in milliseconds (chrono::Utc::now().timestamp_millis()).
pub created_at: i64,
pub updated_at: i64,
}
impl CollectorRecord {
pub fn from_config(
collector_type: CollectorType,
name: impl Into<String>,
enabled: bool,
group_name: impl Into<String>,
config: serde_json::Value,
) -> Self {
let now = chrono::Utc::now().timestamp_millis();
Self {
id: None,
collector_type,
name: name.into(),
source: CollectorSource::Config,
enabled,
group_name: group_name.into(),
config,
created_at: now,
updated_at: now,
}
}
pub fn from_api(
collector_type: CollectorType,
name: impl Into<String>,
enabled: bool,
group_name: impl Into<String>,
config: serde_json::Value,
) -> Self {
let now = chrono::Utc::now().timestamp_millis();
Self {
id: None,
collector_type,
name: name.into(),
source: CollectorSource::Api,
enabled,
group_name: group_name.into(),
config,
created_at: now,
updated_at: now,
}
}
}
/// Outcome counters returned by `CollectorStore::sync_from_config`.
#[derive(Debug, Default)]
pub struct SyncResult {
// Records inserted because no (type, name) match existed yet.
pub added: usize,
// Records that already existed and were overwritten by the sync.
pub updated: usize,
// Config-sourced records removed because they vanished from the config.
pub deleted: usize,
}
/// SQLite-backed persistence layer for collector definitions.
/// Cheap to clone: holds only an `Arc` to the shared connection pool.
#[derive(Clone)]
pub struct CollectorStore {
pool: Arc<SqlitePool>,
}
impl CollectorStore {
pub fn new(pool: Arc<SqlitePool>) -> Self {
Self { pool }
}
pub async fn upsert(&self, record: &CollectorRecord) -> Result<i64, StorageError> {
let now = chrono::Utc::now().timestamp_millis();
let config_json = record.config.to_string();
let existing: Option<(i64,)> =
sqlx::query_as("SELECT id FROM collectors WHERE type = ? AND name = ?")
.bind(record.collector_type.as_ref())
.bind(&record.name)
.fetch_optional(self.pool.inner())
.await?;
if let Some((id,)) = existing {
sqlx::query(
"UPDATE collectors SET enabled = ?, group_name = ?, config = ?, updated_at = ?
WHERE id = ?",
)
.bind(record.enabled)
.bind(&record.group_name)
.bind(&config_json)
.bind(now)
.bind(id)
.execute(self.pool.inner())
.await?;
Ok(id)
} else {
let result = sqlx::query(
"INSERT INTO collectors (type, name, source, enabled, group_name, config, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
)
.bind(record.collector_type.as_ref())
.bind(&record.name)
.bind(record.source.as_ref())
.bind(record.enabled)
.bind(&record.group_name)
.bind(&config_json)
.bind(now)
.bind(now)
.execute(self.pool.inner())
.await?;
Ok(result.last_insert_rowid())
}
}
pub async fn insert_if_not_exists(
&self,
record: &CollectorRecord,
) -> Result<Option<i64>, StorageError> {
let existing: Option<(i64,)> =
sqlx::query_as("SELECT id FROM collectors WHERE type = ? AND name = ?")
.bind(record.collector_type.as_ref())
.bind(&record.name)
.fetch_optional(self.pool.inner())
.await?;
if existing.is_some() {
return Ok(None);
}
let now = chrono::Utc::now().timestamp_millis();
let config_json = record.config.to_string();
let result = sqlx::query(
"INSERT INTO collectors (type, name, source, enabled, group_name, config, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
)
.bind(record.collector_type.as_ref())
.bind(&record.name)
.bind(record.source.as_ref())
.bind(record.enabled)
.bind(&record.group_name)
.bind(&config_json)
.bind(now)
.bind(now)
.execute(self.pool.inner())
.await?;
Ok(Some(result.last_insert_rowid()))
}
pub async fn delete(
&self,
collector_type: CollectorType,
name: &str,
) -> Result<bool, StorageError> {
let result = sqlx::query("DELETE FROM collectors WHERE type = ? AND name = ?")
.bind(collector_type.as_ref())
.bind(name)
.execute(self.pool.inner())
.await?;
Ok(result.rows_affected() > 0)
}
pub async fn list_by_source(
&self,
source: CollectorSource,
) -> Result<Vec<CollectorRecord>, StorageError> {
let rows = sqlx::query(
"SELECT id, type, name, source, enabled, group_name, config, created_at, updated_at
FROM collectors WHERE source = ? ORDER BY type, name",
)
.bind(source.as_ref())
.fetch_all(self.pool.inner())
.await?;
let records = rows
.into_iter()
.map(|row| {
let type_str: String = row.get(1);
let source_str: String = row.get(3);
let config_str: String = row.get(6);
CollectorRecord {
id: Some(row.get(0)),
collector_type: type_str.parse().unwrap_or(CollectorType::Tcp),
name: row.get(2),
source: source_str.parse().unwrap_or(CollectorSource::Config),
enabled: row.get::<i32, _>(4) != 0,
group_name: row.get(5),
config: serde_json::from_str(&config_str).unwrap_or_default(),
created_at: row.get(7),
updated_at: row.get(8),
}
})
.collect();
Ok(records)
}
pub async fn list_all(&self) -> Result<Vec<CollectorRecord>, StorageError> {
let rows = sqlx::query(
"SELECT id, type, name, source, enabled, group_name, config, created_at, updated_at
FROM collectors ORDER BY type, name",
)
.fetch_all(self.pool.inner())
.await?;
let records = rows
.into_iter()
.map(|row| {
let type_str: String = row.get(1);
let source_str: String = row.get(3);
let config_str: String = row.get(6);
CollectorRecord {
id: Some(row.get(0)),
collector_type: type_str.parse().unwrap_or(CollectorType::Tcp),
name: row.get(2),
source: source_str.parse().unwrap_or(CollectorSource::Config),
enabled: row.get::<i32, _>(4) != 0,
group_name: row.get(5),
config: serde_json::from_str(&config_str).unwrap_or_default(),
created_at: row.get(7),
updated_at: row.get(8),
}
})
.collect();
Ok(records)
}
pub async fn list_paginated(
&self,
page: u32,
page_size: u32,
) -> Result<(Vec<CollectorRecord>, u64), StorageError> {
let offset = (page.saturating_sub(1)) * page_size;
let count_row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM collectors")
.fetch_one(self.pool.inner())
.await?;
let total_count = count_row.0 as u64;
let rows = sqlx::query(
"SELECT id, type, name, source, enabled, group_name, config, created_at, updated_at
FROM collectors ORDER BY type, name LIMIT ? OFFSET ?",
)
.bind(page_size as i64)
.bind(offset as i64)
.fetch_all(self.pool.inner())
.await?;
let records = rows
.into_iter()
.map(|row| {
let type_str: String = row.get(1);
let source_str: String = row.get(3);
let config_str: String = row.get(6);
CollectorRecord {
id: Some(row.get(0)),
collector_type: type_str.parse().unwrap_or(CollectorType::Tcp),
name: row.get(2),
source: source_str.parse().unwrap_or(CollectorSource::Config),
enabled: row.get::<i32, _>(4) != 0,
group_name: row.get(5),
config: serde_json::from_str(&config_str).unwrap_or_default(),
created_at: row.get(7),
updated_at: row.get(8),
}
})
.collect();
Ok((records, total_count))
}
pub async fn get(
&self,
collector_type: CollectorType,
name: &str,
) -> Result<Option<CollectorRecord>, StorageError> {
let row = sqlx::query(
"SELECT id, type, name, source, enabled, group_name, config, created_at, updated_at
FROM collectors WHERE type = ? AND name = ?",
)
.bind(collector_type.as_ref())
.bind(name)
.fetch_optional(self.pool.inner())
.await?;
Ok(row.map(|row| {
let type_str: String = row.get(1);
let source_str: String = row.get(3);
let config_str: String = row.get(6);
CollectorRecord {
id: Some(row.get(0)),
collector_type: type_str.parse().unwrap_or(CollectorType::Tcp),
name: row.get(2),
source: source_str.parse().unwrap_or(CollectorSource::Config),
enabled: row.get::<i32, _>(4) != 0,
group_name: row.get(5),
config: serde_json::from_str(&config_str).unwrap_or_default(),
created_at: row.get(7),
updated_at: row.get(8),
}
}))
}
pub async fn sync_from_config(
&self,
configs: Vec<CollectorRecord>,
) -> Result<SyncResult, StorageError> {
let mut result = SyncResult::default();
let existing = self.list_by_source(CollectorSource::Config).await?;
let existing_keys: std::collections::HashSet<_> = existing
.iter()
.map(|r| (r.collector_type, r.name.clone()))
.collect();
let new_keys: std::collections::HashSet<_> = configs
.iter()
.map(|r| (r.collector_type, r.name.clone()))
.collect();
for config in &configs {
let existed = existing_keys.contains(&(config.collector_type, config.name.clone()));
self.upsert(config).await?;
if existed {
result.updated += 1;
} else {
result.added += 1;
}
}
for (collector_type, name) in existing_keys.difference(&new_keys) {
self.delete(*collector_type, name).await?;
result.deleted += 1;
}
Ok(result)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::db::SqlitePool;
    use crate::storage::schema::init_schema;
    use std::sync::Arc;

    /// Fresh in-memory SQLite database with the full schema applied.
    async fn create_test_pool() -> Arc<SqlitePool> {
        let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
        init_schema(pool.inner()).await.unwrap();
        Arc::new(pool)
    }

    #[tokio::test]
    async fn test_collector_crud() {
        let store = CollectorStore::new(create_test_pool().await);
        let rec = CollectorRecord::from_config(
            CollectorType::Tcp,
            "test-tcp",
            true,
            "default",
            serde_json::json!({"host": "127.0.0.1", "port": 6379}),
        );

        // Create.
        let row_id = store.upsert(&rec).await.unwrap();
        assert!(row_id > 0);

        // Read back.
        let loaded = store.get(CollectorType::Tcp, "test-tcp").await.unwrap();
        assert!(loaded.is_some());
        let loaded = loaded.unwrap();
        assert_eq!(loaded.name, "test-tcp");
        assert!(loaded.enabled);

        // Update through a second upsert.
        let mut disabled = rec.clone();
        disabled.enabled = false;
        store.upsert(&disabled).await.unwrap();
        let reloaded = store
            .get(CollectorType::Tcp, "test-tcp")
            .await
            .unwrap()
            .unwrap();
        assert!(!reloaded.enabled);

        // Delete, then confirm the record is gone.
        assert!(store.delete(CollectorType::Tcp, "test-tcp").await.unwrap());
        assert!(store
            .get(CollectorType::Tcp, "test-tcp")
            .await
            .unwrap()
            .is_none());
    }

    #[tokio::test]
    async fn test_collector_sync() {
        let store = CollectorStore::new(create_test_pool().await);
        // Small builder to keep the fixture records terse.
        let cfg = |ty, name: &str| {
            CollectorRecord::from_config(ty, name, true, "default", serde_json::json!({}))
        };

        // First sync: both records are brand new.
        let first = store
            .sync_from_config(vec![
                cfg(CollectorType::Tcp, "tcp-1"),
                cfg(CollectorType::Ping, "ping-1"),
            ])
            .await
            .unwrap();
        assert_eq!(first.added, 2);
        assert_eq!(first.updated, 0);
        assert_eq!(first.deleted, 0);

        // Second sync: tcp-1 persists, http-1 appears, ping-1 disappears.
        let second = store
            .sync_from_config(vec![
                cfg(CollectorType::Tcp, "tcp-1"),
                cfg(CollectorType::Http, "http-1"),
            ])
            .await
            .unwrap();
        assert_eq!(second.added, 1);
        assert_eq!(second.updated, 1);
        assert_eq!(second.deleted, 1);
        assert_eq!(store.list_all().await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_collector_type_enum() {
        use std::str::FromStr;
        // Parsing ignores ASCII case.
        assert_eq!(CollectorType::from_str("tcp").unwrap(), CollectorType::Tcp);
        assert_eq!(
            CollectorType::from_str("PING").unwrap(),
            CollectorType::Ping
        );
        // AsRefStr yields the lowercase serialized form.
        let http_ref: &str = CollectorType::Http.as_ref();
        assert_eq!(http_ref, "http");
    }
}