Skip to main content

heeranjid_sqlx/
postgres.rs

1use heeranjid::Error;
2use sqlx::Executor;
3use sqlx::FromRow;
4
5pub const SCHEMA_SQL: &str = include_str!("../sql/schema.sql");
6pub const SESSION_SQL: &str = include_str!("../sql/functions/session.sql");
7pub const GENERATE_HEERID_SQL: &str = include_str!("../sql/functions/generate_heerid.sql");
8pub const GENERATE_RANJID_SQL: &str = include_str!("../sql/functions/generate_ranjid.sql");
9pub const INSTALL_SQL: &str = concat!(
10    include_str!("../sql/schema.sql"),
11    "\n",
12    include_str!("../sql/functions/session.sql"),
13    "\n",
14    include_str!("../sql/functions/generate_heerid.sql"),
15    "\n",
16    include_str!("../sql/functions/generate_ranjid.sql"),
17);
18pub const FETCH_NODE_SQL: &str = include_str!("../sql/queries/fetch_node.sql");
19pub const FETCH_EPOCH_SQL: &str = include_str!("../sql/queries/fetch_epoch.sql");
20pub const SEED_SQL: &str = include_str!("../sql/seed.sql");
21pub const FETCH_ACTIVE_NODE_SQL: &str = include_str!("../sql/queries/fetch_active_node.sql");
22
23#[derive(Debug, Clone, PartialEq, Eq, FromRow)]
24pub struct HeerNode {
25    pub node_id: i32,
26    pub name: String,
27    pub description: Option<String>,
28    pub is_active: bool,
29}
30
31#[derive(Debug, Clone, PartialEq, Eq, FromRow)]
32pub struct HeerConfig {
33    pub epoch: sqlx::types::time::PrimitiveDateTime,
34}
35
36pub fn validate_heer_node_id(node_id: i32) -> Result<u16, Error> {
37    if !(0..=i32::from(heeranjid::HeerId::MAX_NODE_ID)).contains(&node_id) {
38        return Err(Error::NodeIdOutOfRange {
39            value: node_id.max(0) as u32,
40            bits: heeranjid::HEER_NODE_ID_BITS,
41        });
42    }
43
44    Ok(node_id as u16)
45}
46
47pub async fn install_schema<'e, E>(executor: E) -> Result<(), sqlx::Error>
48where
49    E: Executor<'e, Database = sqlx::Postgres>,
50{
51    sqlx::raw_sql(INSTALL_SQL).execute(executor).await?;
52    Ok(())
53}
54
55pub async fn fetch_node(
56    executor: impl Executor<'_, Database = sqlx::Postgres>,
57    node_id: u16,
58) -> Result<Option<HeerNode>, sqlx::Error> {
59    sqlx::query_as::<_, HeerNode>(FETCH_NODE_SQL)
60        .bind(i32::from(node_id))
61        .fetch_optional(executor)
62        .await
63}
64
65pub async fn fetch_epoch(
66    executor: impl Executor<'_, Database = sqlx::Postgres>,
67) -> Result<Option<sqlx::types::time::PrimitiveDateTime>, sqlx::Error> {
68    let record = sqlx::query_as::<_, HeerConfig>(FETCH_EPOCH_SQL)
69        .fetch_optional(executor)
70        .await?;
71
72    Ok(record.map(|row| row.epoch))
73}
74
75pub async fn fetch_active_node(
76    executor: impl Executor<'_, Database = sqlx::Postgres>,
77    node_id: u16,
78) -> Result<Option<HeerNode>, sqlx::Error> {
79    sqlx::query_as::<_, HeerNode>(FETCH_ACTIVE_NODE_SQL)
80        .bind(i32::from(node_id))
81        .fetch_optional(executor)
82        .await
83}
84
85pub async fn validate_startup(
86    executor: impl Executor<'_, Database = sqlx::Postgres>,
87    node_id: u16,
88) -> Result<HeerNode, crate::StartupError> {
89    let node = fetch_active_node(executor, node_id)
90        .await
91        .map_err(crate::StartupError::Database)?;
92
93    match node {
94        Some(node) => Ok(node),
95        None => Err(crate::StartupError::NodeNotActive(node_id)),
96    }
97}
98
99pub async fn validate_epoch(
100    executor: impl Executor<'_, Database = sqlx::Postgres>,
101) -> Result<sqlx::types::time::PrimitiveDateTime, crate::StartupError> {
102    let epoch = fetch_epoch(executor)
103        .await
104        .map_err(crate::StartupError::Database)?;
105
106    match epoch {
107        Some(epoch) => Ok(epoch),
108        None => Err(crate::StartupError::MissingEpoch),
109    }
110}
111
112pub async fn seed_default_node<'e, E>(executor: E) -> Result<(), sqlx::Error>
113where
114    E: Executor<'e, Database = sqlx::Postgres>,
115{
116    sqlx::raw_sql(SEED_SQL).execute(executor).await?;
117    Ok(())
118}
119
120pub async fn generate_heerid(
121    executor: impl Executor<'_, Database = sqlx::Postgres>,
122    node_id: u16,
123) -> Result<heeranjid::HeerId, crate::GenerateError> {
124    let raw: i64 = sqlx::query_scalar("SELECT generate_id($1)")
125        .bind(i32::from(node_id))
126        .fetch_one(executor)
127        .await?;
128    heeranjid::HeerId::from_i64(raw).map_err(crate::GenerateError::InvalidHeerId)
129}
130
131pub async fn generate_ranjid(
132    executor: impl Executor<'_, Database = sqlx::Postgres>,
133    node_id: u16,
134) -> Result<heeranjid::RanjId, crate::GenerateError> {
135    let uuid: uuid::Uuid = sqlx::query_scalar("SELECT generate_ranjid($1)")
136        .bind(i32::from(node_id))
137        .fetch_one(executor)
138        .await?;
139    heeranjid::RanjId::from_uuid(uuid).map_err(crate::GenerateError::InvalidRanjId)
140}
141
142pub async fn generate_heerids(
143    executor: impl Executor<'_, Database = sqlx::Postgres>,
144    node_id: u16,
145    count: i32,
146) -> Result<Vec<heeranjid::HeerId>, crate::GenerateError> {
147    let rows: Vec<i64> = sqlx::query_scalar("SELECT id FROM generate_ids($1, $2)")
148        .bind(i32::from(node_id))
149        .bind(count)
150        .fetch_all(executor)
151        .await?;
152    rows.into_iter()
153        .map(|raw| heeranjid::HeerId::from_i64(raw).map_err(crate::GenerateError::InvalidHeerId))
154        .collect()
155}
156
157pub async fn set_ranj_node_id(
158    executor: impl Executor<'_, Database = sqlx::Postgres>,
159    node_id: u16,
160) -> Result<(), sqlx::Error> {
161    sqlx::query("SELECT set_heer_ranj_node_id($1)")
162        .bind(i32::from(node_id))
163        .execute(executor)
164        .await?;
165    Ok(())
166}
167
168pub async fn generate_ranjids(
169    executor: impl Executor<'_, Database = sqlx::Postgres>,
170    node_id: u16,
171    count: i32,
172) -> Result<Vec<heeranjid::RanjId>, crate::GenerateError> {
173    let rows: Vec<uuid::Uuid> = sqlx::query_scalar("SELECT id FROM generate_ranjids($1, $2)")
174        .bind(i32::from(node_id))
175        .bind(count)
176        .fetch_all(executor)
177        .await?;
178    rows.into_iter()
179        .map(|uuid| heeranjid::RanjId::from_uuid(uuid).map_err(crate::GenerateError::InvalidRanjId))
180        .collect()
181}