Skip to main content

heeranjid_sqlx/
postgres.rs

1use heeranjid::Error;
2use sqlx::Executor;
3use sqlx::FromRow;
4
5pub const SCHEMA_SQL: &str = include_str!("../sql/schema.sql");
6pub const SESSION_SQL: &str = include_str!("../sql/functions/session.sql");
7pub const GENERATE_HEERID_SQL: &str =
8    include_str!("../sql/functions/generate_heerid.sql");
9pub const GENERATE_RANJID_SQL: &str =
10    include_str!("../sql/functions/generate_ranjid.sql");
11pub const INSTALL_SQL: &str = concat!(
12    include_str!("../sql/schema.sql"),
13    "\n",
14    include_str!("../sql/functions/session.sql"),
15    "\n",
16    include_str!("../sql/functions/generate_heerid.sql"),
17    "\n",
18    include_str!("../sql/functions/generate_ranjid.sql"),
19);
20pub const FETCH_NODE_SQL: &str = include_str!("../sql/queries/fetch_node.sql");
21pub const FETCH_EPOCH_SQL: &str = include_str!("../sql/queries/fetch_epoch.sql");
22pub const SEED_SQL: &str = include_str!("../sql/seed.sql");
23pub const FETCH_ACTIVE_NODE_SQL: &str =
24    include_str!("../sql/queries/fetch_active_node.sql");
25
26#[derive(Debug, Clone, PartialEq, Eq, FromRow)]
27pub struct HeerNode {
28    pub node_id: i32,
29    pub name: String,
30    pub description: Option<String>,
31    pub is_active: bool,
32}
33
34#[derive(Debug, Clone, PartialEq, Eq, FromRow)]
35pub struct HeerConfig {
36    pub epoch: sqlx::types::time::PrimitiveDateTime,
37}
38
39pub fn validate_heer_node_id(node_id: i32) -> Result<u16, Error> {
40    if !(0..=i32::from(heeranjid::HeerId::MAX_NODE_ID)).contains(&node_id) {
41        return Err(Error::NodeIdOutOfRange {
42            value: node_id.max(0) as u32,
43            bits: heeranjid::HEER_NODE_ID_BITS,
44        });
45    }
46
47    Ok(node_id as u16)
48}
49
50pub async fn install_schema<'e, E>(executor: E) -> Result<(), sqlx::Error>
51where
52    E: Executor<'e, Database = sqlx::Postgres>,
53{
54    sqlx::raw_sql(INSTALL_SQL).execute(executor).await?;
55    Ok(())
56}
57
58pub async fn fetch_node(
59    executor: impl Executor<'_, Database = sqlx::Postgres>,
60    node_id: u16,
61) -> Result<Option<HeerNode>, sqlx::Error> {
62    sqlx::query_as::<_, HeerNode>(FETCH_NODE_SQL)
63        .bind(i32::from(node_id))
64        .fetch_optional(executor)
65        .await
66}
67
68pub async fn fetch_epoch(
69    executor: impl Executor<'_, Database = sqlx::Postgres>,
70) -> Result<Option<sqlx::types::time::PrimitiveDateTime>, sqlx::Error> {
71    let record = sqlx::query_as::<_, HeerConfig>(FETCH_EPOCH_SQL)
72        .fetch_optional(executor)
73        .await?;
74
75    Ok(record.map(|row| row.epoch))
76}
77
78pub async fn fetch_active_node(
79    executor: impl Executor<'_, Database = sqlx::Postgres>,
80    node_id: u16,
81) -> Result<Option<HeerNode>, sqlx::Error> {
82    sqlx::query_as::<_, HeerNode>(FETCH_ACTIVE_NODE_SQL)
83        .bind(i32::from(node_id))
84        .fetch_optional(executor)
85        .await
86}
87
88pub async fn validate_startup(
89    executor: impl Executor<'_, Database = sqlx::Postgres>,
90    node_id: u16,
91) -> Result<HeerNode, crate::StartupError> {
92    let node = fetch_active_node(executor, node_id)
93        .await
94        .map_err(crate::StartupError::Database)?;
95
96    match node {
97        Some(node) => Ok(node),
98        None => Err(crate::StartupError::NodeNotActive(node_id)),
99    }
100}
101
102pub async fn validate_epoch(
103    executor: impl Executor<'_, Database = sqlx::Postgres>,
104) -> Result<sqlx::types::time::PrimitiveDateTime, crate::StartupError> {
105    let epoch = fetch_epoch(executor)
106        .await
107        .map_err(crate::StartupError::Database)?;
108
109    match epoch {
110        Some(epoch) => Ok(epoch),
111        None => Err(crate::StartupError::MissingEpoch),
112    }
113}
114
115pub async fn seed_default_node<'e, E>(executor: E) -> Result<(), sqlx::Error>
116where
117    E: Executor<'e, Database = sqlx::Postgres>,
118{
119    sqlx::raw_sql(SEED_SQL).execute(executor).await?;
120    Ok(())
121}
122
123pub async fn generate_heerid(
124    executor: impl Executor<'_, Database = sqlx::Postgres>,
125    node_id: u16,
126) -> Result<heeranjid::HeerId, crate::GenerateError> {
127    let raw: i64 = sqlx::query_scalar("SELECT generate_id($1)")
128        .bind(i32::from(node_id))
129        .fetch_one(executor)
130        .await?;
131    heeranjid::HeerId::from_i64(raw).map_err(crate::GenerateError::InvalidHeerId)
132}
133
134pub async fn generate_ranjid(
135    executor: impl Executor<'_, Database = sqlx::Postgres>,
136    node_id: u16,
137) -> Result<heeranjid::RanjId, crate::GenerateError> {
138    let uuid: uuid::Uuid = sqlx::query_scalar("SELECT generate_ranjid($1)")
139        .bind(i32::from(node_id))
140        .fetch_one(executor)
141        .await?;
142    heeranjid::RanjId::from_uuid(uuid).map_err(crate::GenerateError::InvalidRanjId)
143}
144
145pub async fn generate_heerids(
146    executor: impl Executor<'_, Database = sqlx::Postgres>,
147    node_id: u16,
148    count: i32,
149) -> Result<Vec<heeranjid::HeerId>, crate::GenerateError> {
150    let rows: Vec<i64> = sqlx::query_scalar("SELECT id FROM generate_ids($1, $2)")
151        .bind(i32::from(node_id))
152        .bind(count)
153        .fetch_all(executor)
154        .await?;
155    rows.into_iter()
156        .map(|raw| heeranjid::HeerId::from_i64(raw).map_err(crate::GenerateError::InvalidHeerId))
157        .collect()
158}
159
160pub async fn set_ranj_node_id(
161    executor: impl Executor<'_, Database = sqlx::Postgres>,
162    node_id: u16,
163) -> Result<(), sqlx::Error> {
164    sqlx::query("SELECT set_heer_ranj_node_id($1)")
165        .bind(i32::from(node_id))
166        .execute(executor)
167        .await?;
168    Ok(())
169}
170
171pub async fn generate_ranjids(
172    executor: impl Executor<'_, Database = sqlx::Postgres>,
173    node_id: u16,
174    count: i32,
175) -> Result<Vec<heeranjid::RanjId>, crate::GenerateError> {
176    let rows: Vec<uuid::Uuid> = sqlx::query_scalar("SELECT id FROM generate_ranjids($1, $2)")
177        .bind(i32::from(node_id))
178        .bind(count)
179        .fetch_all(executor)
180        .await?;
181    rows.into_iter()
182        .map(|uuid| heeranjid::RanjId::from_uuid(uuid).map_err(crate::GenerateError::InvalidRanjId))
183        .collect()
184}