Skip to main content

heeranjid_sqlx/
postgres.rs

1use heeranjid::Error;
2use sqlx::Executor;
3use sqlx::FromRow;
4
5fn map_sql_error(err: sqlx::Error) -> crate::GenerateError {
6    if let Some(db_err) = err.as_database_error() {
7        let code = db_err.code();
8        let message = db_err.message().to_string();
9        match code.as_deref() {
10            Some("50021") => return crate::GenerateError::LogicalDrift { message },
11            Some("50020") => return crate::GenerateError::ClockRollback { message },
12            Some("50022") => return crate::GenerateError::HardClockRollback { message },
13            _ => {}
14        }
15    }
16    crate::GenerateError::Database(err)
17}
18
19pub const SCHEMA_SQL: &str = include_str!("../sql/schema.sql");
20pub const SESSION_SQL: &str = include_str!("../sql/functions/session.sql");
21pub const GENERATE_HEERID_SQL: &str = include_str!("../sql/functions/generate_heerid.sql");
22pub const GENERATE_RANJID_SQL: &str = include_str!("../sql/functions/generate_ranjid.sql");
23pub const INSTALL_SQL: &str = concat!(
24    include_str!("../sql/schema.sql"),
25    "\n",
26    include_str!("../sql/functions/session.sql"),
27    "\n",
28    include_str!("../sql/functions/generate_heerid.sql"),
29    "\n",
30    include_str!("../sql/functions/generate_ranjid.sql"),
31);
32pub const CONFIGURE_SQL: &str = include_str!("../sql/functions/configure.sql");
33pub const FETCH_NODE_SQL: &str = include_str!("../sql/queries/fetch_node.sql");
34pub const FETCH_EPOCH_SQL: &str = include_str!("../sql/queries/fetch_epoch.sql");
35pub const SEED_SQL: &str = include_str!("../sql/seed.sql");
36pub const FETCH_ACTIVE_NODE_SQL: &str = include_str!("../sql/queries/fetch_active_node.sql");
37
38#[derive(Debug, Clone, PartialEq, Eq, FromRow)]
39pub struct HeerNode {
40    pub node_id: i32,
41    pub name: String,
42    pub description: Option<String>,
43    pub is_active: bool,
44}
45
46#[derive(Debug, Clone, PartialEq, Eq, FromRow)]
47pub struct HeerConfig {
48    pub epoch: sqlx::types::time::PrimitiveDateTime,
49}
50
51pub fn validate_heer_node_id(node_id: i32) -> Result<u16, Error> {
52    if !(0..=i32::from(heeranjid::HeerId::MAX_NODE_ID)).contains(&node_id) {
53        return Err(Error::NodeIdOutOfRange {
54            value: node_id.max(0) as u32,
55            bits: heeranjid::HEER_NODE_ID_BITS,
56        });
57    }
58
59    Ok(node_id as u16)
60}
61
62pub async fn install_schema<'e, E>(executor: E) -> Result<(), sqlx::Error>
63where
64    E: Executor<'e, Database = sqlx::Postgres>,
65{
66    sqlx::raw_sql(INSTALL_SQL).execute(executor).await?;
67    Ok(())
68}
69
70pub async fn install_configure<'e, E>(executor: E) -> Result<(), sqlx::Error>
71where
72    E: Executor<'e, Database = sqlx::Postgres>,
73{
74    sqlx::raw_sql(CONFIGURE_SQL).execute(executor).await?;
75    Ok(())
76}
77
78pub async fn fetch_node(
79    executor: impl Executor<'_, Database = sqlx::Postgres>,
80    node_id: u16,
81) -> Result<Option<HeerNode>, sqlx::Error> {
82    sqlx::query_as::<_, HeerNode>(FETCH_NODE_SQL)
83        .bind(i32::from(node_id))
84        .fetch_optional(executor)
85        .await
86}
87
88pub async fn fetch_epoch(
89    executor: impl Executor<'_, Database = sqlx::Postgres>,
90) -> Result<Option<sqlx::types::time::PrimitiveDateTime>, sqlx::Error> {
91    let record = sqlx::query_as::<_, HeerConfig>(FETCH_EPOCH_SQL)
92        .fetch_optional(executor)
93        .await?;
94
95    Ok(record.map(|row| row.epoch))
96}
97
98pub async fn fetch_active_node(
99    executor: impl Executor<'_, Database = sqlx::Postgres>,
100    node_id: u16,
101) -> Result<Option<HeerNode>, sqlx::Error> {
102    sqlx::query_as::<_, HeerNode>(FETCH_ACTIVE_NODE_SQL)
103        .bind(i32::from(node_id))
104        .fetch_optional(executor)
105        .await
106}
107
108pub async fn validate_startup(
109    executor: impl Executor<'_, Database = sqlx::Postgres>,
110    node_id: u16,
111) -> Result<HeerNode, crate::StartupError> {
112    let node = fetch_active_node(executor, node_id)
113        .await
114        .map_err(crate::StartupError::Database)?;
115
116    match node {
117        Some(node) => Ok(node),
118        None => Err(crate::StartupError::NodeNotActive(node_id)),
119    }
120}
121
122pub async fn validate_epoch(
123    executor: impl Executor<'_, Database = sqlx::Postgres>,
124) -> Result<sqlx::types::time::PrimitiveDateTime, crate::StartupError> {
125    let epoch = fetch_epoch(executor)
126        .await
127        .map_err(crate::StartupError::Database)?;
128
129    match epoch {
130        Some(epoch) => Ok(epoch),
131        None => Err(crate::StartupError::MissingEpoch),
132    }
133}
134
135pub async fn seed_default_node<'e, E>(executor: E) -> Result<(), sqlx::Error>
136where
137    E: Executor<'e, Database = sqlx::Postgres>,
138{
139    sqlx::raw_sql(SEED_SQL).execute(executor).await?;
140    Ok(())
141}
142
143pub async fn generate_heerid(
144    executor: impl Executor<'_, Database = sqlx::Postgres>,
145    node_id: u16,
146) -> Result<heeranjid::HeerId, crate::GenerateError> {
147    let raw: i64 = sqlx::query_scalar("SELECT generate_id($1)")
148        .bind(i32::from(node_id))
149        .fetch_one(executor)
150        .await
151        .map_err(map_sql_error)?;
152    heeranjid::HeerId::from_i64(raw).map_err(crate::GenerateError::InvalidHeerId)
153}
154
155pub async fn generate_ranjid(
156    executor: impl Executor<'_, Database = sqlx::Postgres>,
157    node_id: u16,
158) -> Result<heeranjid::RanjId, crate::GenerateError> {
159    let uuid: uuid::Uuid = sqlx::query_scalar("SELECT generate_ranjid($1)")
160        .bind(i32::from(node_id))
161        .fetch_one(executor)
162        .await
163        .map_err(map_sql_error)?;
164    heeranjid::RanjId::from_uuid(uuid).map_err(crate::GenerateError::InvalidRanjId)
165}
166
167pub async fn generate_heerids(
168    executor: impl Executor<'_, Database = sqlx::Postgres>,
169    node_id: u16,
170    count: i32,
171) -> Result<Vec<heeranjid::HeerId>, crate::GenerateError> {
172    let rows: Vec<i64> = sqlx::query_scalar("SELECT id FROM generate_ids($1, $2)")
173        .bind(i32::from(node_id))
174        .bind(count)
175        .fetch_all(executor)
176        .await
177        .map_err(map_sql_error)?;
178    rows.into_iter()
179        .map(|raw| heeranjid::HeerId::from_i64(raw).map_err(crate::GenerateError::InvalidHeerId))
180        .collect()
181}
182
183pub async fn set_ranj_node_id(
184    executor: impl Executor<'_, Database = sqlx::Postgres>,
185    node_id: u16,
186) -> Result<(), sqlx::Error> {
187    sqlx::query("SELECT set_heer_ranj_node_id($1)")
188        .bind(i32::from(node_id))
189        .execute(executor)
190        .await?;
191    Ok(())
192}
193
194pub async fn generate_ranjids(
195    executor: impl Executor<'_, Database = sqlx::Postgres>,
196    node_id: u16,
197    count: i32,
198) -> Result<Vec<heeranjid::RanjId>, crate::GenerateError> {
199    let rows: Vec<uuid::Uuid> = sqlx::query_scalar("SELECT id FROM generate_ranjids($1, $2)")
200        .bind(i32::from(node_id))
201        .bind(count)
202        .fetch_all(executor)
203        .await
204        .map_err(map_sql_error)?;
205    rows.into_iter()
206        .map(|uuid| heeranjid::RanjId::from_uuid(uuid).map_err(crate::GenerateError::InvalidRanjId))
207        .collect()
208}