Skip to main content

heeranjid_sqlx/
postgres.rs

1use heeranjid::Error;
2use sqlx::Executor;
3use sqlx::FromRow;
4
5fn map_sql_error(err: sqlx::Error) -> crate::GenerateError {
6    if let Some(db_err) = err.as_database_error() {
7        let code = db_err.code();
8        let message = db_err.message().to_string();
9        match code.as_deref() {
10            Some("50021") => return crate::GenerateError::LogicalDrift { message },
11            Some("50020") => return crate::GenerateError::ClockRollback { message },
12            Some("50022") => return crate::GenerateError::HardClockRollback { message },
13            _ => {}
14        }
15    }
16    crate::GenerateError::Database(err)
17}
18
19pub const SCHEMA_SQL: &str = include_str!("../sql/schema.sql");
20pub const SESSION_SQL: &str = include_str!("../sql/functions/session.sql");
21pub const GENERATE_HEERID_SQL: &str = include_str!("../sql/functions/generate_heerid.sql");
22pub const GENERATE_RANJID_SQL: &str = include_str!("../sql/functions/generate_ranjid.sql");
23pub const INSTALL_SQL: &str = concat!(
24    include_str!("../sql/schema.sql"),
25    "\n",
26    include_str!("../sql/functions/session.sql"),
27    "\n",
28    include_str!("../sql/functions/generate_heerid.sql"),
29    "\n",
30    include_str!("../sql/functions/generate_ranjid.sql"),
31);
32pub const FETCH_NODE_SQL: &str = include_str!("../sql/queries/fetch_node.sql");
33pub const FETCH_EPOCH_SQL: &str = include_str!("../sql/queries/fetch_epoch.sql");
34pub const SEED_SQL: &str = include_str!("../sql/seed.sql");
35pub const FETCH_ACTIVE_NODE_SQL: &str = include_str!("../sql/queries/fetch_active_node.sql");
36
37#[derive(Debug, Clone, PartialEq, Eq, FromRow)]
38pub struct HeerNode {
39    pub node_id: i32,
40    pub name: String,
41    pub description: Option<String>,
42    pub is_active: bool,
43}
44
45#[derive(Debug, Clone, PartialEq, Eq, FromRow)]
46pub struct HeerConfig {
47    pub epoch: sqlx::types::time::PrimitiveDateTime,
48}
49
50pub fn validate_heer_node_id(node_id: i32) -> Result<u16, Error> {
51    if !(0..=i32::from(heeranjid::HeerId::MAX_NODE_ID)).contains(&node_id) {
52        return Err(Error::NodeIdOutOfRange {
53            value: node_id.max(0) as u32,
54            bits: heeranjid::HEER_NODE_ID_BITS,
55        });
56    }
57
58    Ok(node_id as u16)
59}
60
61pub async fn install_schema<'e, E>(executor: E) -> Result<(), sqlx::Error>
62where
63    E: Executor<'e, Database = sqlx::Postgres>,
64{
65    sqlx::raw_sql(INSTALL_SQL).execute(executor).await?;
66    Ok(())
67}
68
69pub async fn fetch_node(
70    executor: impl Executor<'_, Database = sqlx::Postgres>,
71    node_id: u16,
72) -> Result<Option<HeerNode>, sqlx::Error> {
73    sqlx::query_as::<_, HeerNode>(FETCH_NODE_SQL)
74        .bind(i32::from(node_id))
75        .fetch_optional(executor)
76        .await
77}
78
79pub async fn fetch_epoch(
80    executor: impl Executor<'_, Database = sqlx::Postgres>,
81) -> Result<Option<sqlx::types::time::PrimitiveDateTime>, sqlx::Error> {
82    let record = sqlx::query_as::<_, HeerConfig>(FETCH_EPOCH_SQL)
83        .fetch_optional(executor)
84        .await?;
85
86    Ok(record.map(|row| row.epoch))
87}
88
89pub async fn fetch_active_node(
90    executor: impl Executor<'_, Database = sqlx::Postgres>,
91    node_id: u16,
92) -> Result<Option<HeerNode>, sqlx::Error> {
93    sqlx::query_as::<_, HeerNode>(FETCH_ACTIVE_NODE_SQL)
94        .bind(i32::from(node_id))
95        .fetch_optional(executor)
96        .await
97}
98
99pub async fn validate_startup(
100    executor: impl Executor<'_, Database = sqlx::Postgres>,
101    node_id: u16,
102) -> Result<HeerNode, crate::StartupError> {
103    let node = fetch_active_node(executor, node_id)
104        .await
105        .map_err(crate::StartupError::Database)?;
106
107    match node {
108        Some(node) => Ok(node),
109        None => Err(crate::StartupError::NodeNotActive(node_id)),
110    }
111}
112
113pub async fn validate_epoch(
114    executor: impl Executor<'_, Database = sqlx::Postgres>,
115) -> Result<sqlx::types::time::PrimitiveDateTime, crate::StartupError> {
116    let epoch = fetch_epoch(executor)
117        .await
118        .map_err(crate::StartupError::Database)?;
119
120    match epoch {
121        Some(epoch) => Ok(epoch),
122        None => Err(crate::StartupError::MissingEpoch),
123    }
124}
125
126pub async fn seed_default_node<'e, E>(executor: E) -> Result<(), sqlx::Error>
127where
128    E: Executor<'e, Database = sqlx::Postgres>,
129{
130    sqlx::raw_sql(SEED_SQL).execute(executor).await?;
131    Ok(())
132}
133
134pub async fn generate_heerid(
135    executor: impl Executor<'_, Database = sqlx::Postgres>,
136    node_id: u16,
137) -> Result<heeranjid::HeerId, crate::GenerateError> {
138    let raw: i64 = sqlx::query_scalar("SELECT generate_id($1)")
139        .bind(i32::from(node_id))
140        .fetch_one(executor)
141        .await
142        .map_err(map_sql_error)?;
143    heeranjid::HeerId::from_i64(raw).map_err(crate::GenerateError::InvalidHeerId)
144}
145
146pub async fn generate_ranjid(
147    executor: impl Executor<'_, Database = sqlx::Postgres>,
148    node_id: u16,
149) -> Result<heeranjid::RanjId, crate::GenerateError> {
150    let uuid: uuid::Uuid = sqlx::query_scalar("SELECT generate_ranjid($1)")
151        .bind(i32::from(node_id))
152        .fetch_one(executor)
153        .await
154        .map_err(map_sql_error)?;
155    heeranjid::RanjId::from_uuid(uuid).map_err(crate::GenerateError::InvalidRanjId)
156}
157
158pub async fn generate_heerids(
159    executor: impl Executor<'_, Database = sqlx::Postgres>,
160    node_id: u16,
161    count: i32,
162) -> Result<Vec<heeranjid::HeerId>, crate::GenerateError> {
163    let rows: Vec<i64> = sqlx::query_scalar("SELECT id FROM generate_ids($1, $2)")
164        .bind(i32::from(node_id))
165        .bind(count)
166        .fetch_all(executor)
167        .await
168        .map_err(map_sql_error)?;
169    rows.into_iter()
170        .map(|raw| heeranjid::HeerId::from_i64(raw).map_err(crate::GenerateError::InvalidHeerId))
171        .collect()
172}
173
174pub async fn set_ranj_node_id(
175    executor: impl Executor<'_, Database = sqlx::Postgres>,
176    node_id: u16,
177) -> Result<(), sqlx::Error> {
178    sqlx::query("SELECT set_heer_ranj_node_id($1)")
179        .bind(i32::from(node_id))
180        .execute(executor)
181        .await?;
182    Ok(())
183}
184
185pub async fn generate_ranjids(
186    executor: impl Executor<'_, Database = sqlx::Postgres>,
187    node_id: u16,
188    count: i32,
189) -> Result<Vec<heeranjid::RanjId>, crate::GenerateError> {
190    let rows: Vec<uuid::Uuid> = sqlx::query_scalar("SELECT id FROM generate_ranjids($1, $2)")
191        .bind(i32::from(node_id))
192        .bind(count)
193        .fetch_all(executor)
194        .await
195        .map_err(map_sql_error)?;
196    rows.into_iter()
197        .map(|uuid| heeranjid::RanjId::from_uuid(uuid).map_err(crate::GenerateError::InvalidRanjId))
198        .collect()
199}