heeranjid_sqlx/
postgres.rs1use heeranjid::Error;
2use sqlx::Executor;
3use sqlx::FromRow;
4
5fn map_sql_error(err: sqlx::Error) -> crate::GenerateError {
6 if let Some(db_err) = err.as_database_error() {
7 let code = db_err.code();
8 let message = db_err.message().to_string();
9 match code.as_deref() {
10 Some("50021") => return crate::GenerateError::LogicalDrift { message },
11 Some("50020") => return crate::GenerateError::ClockRollback { message },
12 Some("50022") => return crate::GenerateError::HardClockRollback { message },
13 _ => {}
14 }
15 }
16 crate::GenerateError::Database(err)
17}
18
19pub const SCHEMA_SQL: &str = include_str!("../sql/schema.sql");
20pub const SESSION_SQL: &str = include_str!("../sql/functions/session.sql");
21pub const GENERATE_HEERID_SQL: &str = include_str!("../sql/functions/generate_heerid.sql");
22pub const GENERATE_RANJID_SQL: &str = include_str!("../sql/functions/generate_ranjid.sql");
23pub const INSTALL_SQL: &str = concat!(
24 include_str!("../sql/schema.sql"),
25 "\n",
26 include_str!("../sql/functions/session.sql"),
27 "\n",
28 include_str!("../sql/functions/generate_heerid.sql"),
29 "\n",
30 include_str!("../sql/functions/generate_ranjid.sql"),
31);
32pub const FETCH_NODE_SQL: &str = include_str!("../sql/queries/fetch_node.sql");
33pub const FETCH_EPOCH_SQL: &str = include_str!("../sql/queries/fetch_epoch.sql");
34pub const SEED_SQL: &str = include_str!("../sql/seed.sql");
35pub const FETCH_ACTIVE_NODE_SQL: &str = include_str!("../sql/queries/fetch_active_node.sql");
36
37#[derive(Debug, Clone, PartialEq, Eq, FromRow)]
38pub struct HeerNode {
39 pub node_id: i32,
40 pub name: String,
41 pub description: Option<String>,
42 pub is_active: bool,
43}
44
45#[derive(Debug, Clone, PartialEq, Eq, FromRow)]
46pub struct HeerConfig {
47 pub epoch: sqlx::types::time::PrimitiveDateTime,
48}
49
50pub fn validate_heer_node_id(node_id: i32) -> Result<u16, Error> {
51 if !(0..=i32::from(heeranjid::HeerId::MAX_NODE_ID)).contains(&node_id) {
52 return Err(Error::NodeIdOutOfRange {
53 value: node_id.max(0) as u32,
54 bits: heeranjid::HEER_NODE_ID_BITS,
55 });
56 }
57
58 Ok(node_id as u16)
59}
60
61pub async fn install_schema<'e, E>(executor: E) -> Result<(), sqlx::Error>
62where
63 E: Executor<'e, Database = sqlx::Postgres>,
64{
65 sqlx::raw_sql(INSTALL_SQL).execute(executor).await?;
66 Ok(())
67}
68
69pub async fn fetch_node(
70 executor: impl Executor<'_, Database = sqlx::Postgres>,
71 node_id: u16,
72) -> Result<Option<HeerNode>, sqlx::Error> {
73 sqlx::query_as::<_, HeerNode>(FETCH_NODE_SQL)
74 .bind(i32::from(node_id))
75 .fetch_optional(executor)
76 .await
77}
78
79pub async fn fetch_epoch(
80 executor: impl Executor<'_, Database = sqlx::Postgres>,
81) -> Result<Option<sqlx::types::time::PrimitiveDateTime>, sqlx::Error> {
82 let record = sqlx::query_as::<_, HeerConfig>(FETCH_EPOCH_SQL)
83 .fetch_optional(executor)
84 .await?;
85
86 Ok(record.map(|row| row.epoch))
87}
88
89pub async fn fetch_active_node(
90 executor: impl Executor<'_, Database = sqlx::Postgres>,
91 node_id: u16,
92) -> Result<Option<HeerNode>, sqlx::Error> {
93 sqlx::query_as::<_, HeerNode>(FETCH_ACTIVE_NODE_SQL)
94 .bind(i32::from(node_id))
95 .fetch_optional(executor)
96 .await
97}
98
99pub async fn validate_startup(
100 executor: impl Executor<'_, Database = sqlx::Postgres>,
101 node_id: u16,
102) -> Result<HeerNode, crate::StartupError> {
103 let node = fetch_active_node(executor, node_id)
104 .await
105 .map_err(crate::StartupError::Database)?;
106
107 match node {
108 Some(node) => Ok(node),
109 None => Err(crate::StartupError::NodeNotActive(node_id)),
110 }
111}
112
113pub async fn validate_epoch(
114 executor: impl Executor<'_, Database = sqlx::Postgres>,
115) -> Result<sqlx::types::time::PrimitiveDateTime, crate::StartupError> {
116 let epoch = fetch_epoch(executor)
117 .await
118 .map_err(crate::StartupError::Database)?;
119
120 match epoch {
121 Some(epoch) => Ok(epoch),
122 None => Err(crate::StartupError::MissingEpoch),
123 }
124}
125
126pub async fn seed_default_node<'e, E>(executor: E) -> Result<(), sqlx::Error>
127where
128 E: Executor<'e, Database = sqlx::Postgres>,
129{
130 sqlx::raw_sql(SEED_SQL).execute(executor).await?;
131 Ok(())
132}
133
134pub async fn generate_heerid(
135 executor: impl Executor<'_, Database = sqlx::Postgres>,
136 node_id: u16,
137) -> Result<heeranjid::HeerId, crate::GenerateError> {
138 let raw: i64 = sqlx::query_scalar("SELECT generate_id($1)")
139 .bind(i32::from(node_id))
140 .fetch_one(executor)
141 .await
142 .map_err(map_sql_error)?;
143 heeranjid::HeerId::from_i64(raw).map_err(crate::GenerateError::InvalidHeerId)
144}
145
146pub async fn generate_ranjid(
147 executor: impl Executor<'_, Database = sqlx::Postgres>,
148 node_id: u16,
149) -> Result<heeranjid::RanjId, crate::GenerateError> {
150 let uuid: uuid::Uuid = sqlx::query_scalar("SELECT generate_ranjid($1)")
151 .bind(i32::from(node_id))
152 .fetch_one(executor)
153 .await
154 .map_err(map_sql_error)?;
155 heeranjid::RanjId::from_uuid(uuid).map_err(crate::GenerateError::InvalidRanjId)
156}
157
158pub async fn generate_heerids(
159 executor: impl Executor<'_, Database = sqlx::Postgres>,
160 node_id: u16,
161 count: i32,
162) -> Result<Vec<heeranjid::HeerId>, crate::GenerateError> {
163 let rows: Vec<i64> = sqlx::query_scalar("SELECT id FROM generate_ids($1, $2)")
164 .bind(i32::from(node_id))
165 .bind(count)
166 .fetch_all(executor)
167 .await
168 .map_err(map_sql_error)?;
169 rows.into_iter()
170 .map(|raw| heeranjid::HeerId::from_i64(raw).map_err(crate::GenerateError::InvalidHeerId))
171 .collect()
172}
173
174pub async fn set_ranj_node_id(
175 executor: impl Executor<'_, Database = sqlx::Postgres>,
176 node_id: u16,
177) -> Result<(), sqlx::Error> {
178 sqlx::query("SELECT set_heer_ranj_node_id($1)")
179 .bind(i32::from(node_id))
180 .execute(executor)
181 .await?;
182 Ok(())
183}
184
185pub async fn generate_ranjids(
186 executor: impl Executor<'_, Database = sqlx::Postgres>,
187 node_id: u16,
188 count: i32,
189) -> Result<Vec<heeranjid::RanjId>, crate::GenerateError> {
190 let rows: Vec<uuid::Uuid> = sqlx::query_scalar("SELECT id FROM generate_ranjids($1, $2)")
191 .bind(i32::from(node_id))
192 .bind(count)
193 .fetch_all(executor)
194 .await
195 .map_err(map_sql_error)?;
196 rows.into_iter()
197 .map(|uuid| heeranjid::RanjId::from_uuid(uuid).map_err(crate::GenerateError::InvalidRanjId))
198 .collect()
199}