wick_sql/sqlx/
component.rs

1use std::collections::HashMap;
2
3use flow_component::ComponentError;
4use futures::stream::BoxStream;
5use futures::StreamExt;
6use serde_json::Value;
7use sqlx::{PgPool, SqlitePool};
8use url::Url;
9use wick_config::config::components::{ComponentConfig, OperationConfig, SqlComponentConfig};
10use wick_config::config::ErrorBehavior;
11use wick_config::{ConfigValidation, Resolver};
12use wick_interface_types::{Field, Type};
13
14use crate::common::sql_wrapper::ConvertedType;
15use crate::common::{ClientConnection, Connection, DatabaseProvider};
16use crate::sqlx::{postgres, sqlite};
17use crate::{common, Error};
18
19#[derive(Debug, Clone)]
20enum CtxPool {
21  Postgres(PgPool),
22  SqlLite(SqlitePool),
23}
24
25impl CtxPool {
26  fn run_query<'a, 'b>(&'a self, querystr: &'b str, args: Vec<ConvertedType>) -> BoxStream<'a, Result<Value, Error>>
27  where
28    'b: 'a,
29  {
30    match self {
31      CtxPool::Postgres(c) => {
32        let query = postgres::make_query(querystr, args);
33        let stream = query.fetch(c).map(|res| res.map(postgres::SerMapRow::from)).map(|res| {
34          res
35            .map(|el| serde_json::to_value(el).unwrap_or(Value::Null))
36            .map_err(|e| Error::Fetch(e.to_string()))
37        });
38
39        stream.boxed()
40      }
41      CtxPool::SqlLite(c) => {
42        let query = sqlite::make_query(querystr, args);
43        let stream = query.fetch(c).map(|res| res.map(sqlite::SerMapRow::from)).map(|res| {
44          res
45            .map(|el| serde_json::to_value(el).unwrap_or(Value::Null))
46            .map_err(|e| Error::Fetch(e.to_string()))
47        });
48
49        stream.boxed()
50      }
51    }
52  }
53
54  async fn run_exec<'a, 'q>(&'a self, query: &'q str, args: Vec<ConvertedType>) -> Result<u64, Error>
55  where
56    'q: 'a,
57  {
58    let result = match self {
59      CtxPool::Postgres(c) => {
60        let query = postgres::make_query(query, args);
61        query.execute(c).await.map(|r| r.rows_affected())
62      }
63      CtxPool::SqlLite(c) => {
64        let query = sqlite::make_query(query, args);
65        query.execute(c).await.map(|r| r.rows_affected())
66      }
67    };
68    result.map_err(|e| Error::Exec(e.to_string()))
69  }
70}
71
72#[derive(Clone)]
73pub(crate) struct Context {
74  db: CtxPool,
75}
76
77impl std::fmt::Debug for Context {
78  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79    f.debug_struct("Context").finish()
80  }
81}
82
83impl Context {}
84
85/// The SQLx component.
86#[derive(Debug, Clone)]
87#[must_use]
88pub(crate) struct SqlXComponent {
89  context: Context,
90  prepared_queries: HashMap<String, String>,
91}
92
93impl SqlXComponent {
94  /// Create a new SQLx component.
95  pub(crate) async fn new(config: SqlComponentConfig, resolver: &Resolver) -> Result<Self, Error> {
96    validate(&config, resolver)?;
97    let url = common::convert_url_resource(resolver, config.resource())?;
98    let context = init_context(&config, &url).await?;
99    let mut queries = HashMap::new();
100    trace!(count=%config.operations().len(), "preparing queries");
101    for op in config.operations() {
102      queries.insert(op.name().to_owned(), op.query().to_owned());
103    }
104    Ok(Self {
105      context,
106      prepared_queries: queries,
107    })
108  }
109}
110
111#[async_trait::async_trait]
112impl DatabaseProvider for SqlXComponent {
113  fn get_statement<'a>(&'a self, id: &'a str) -> Option<&'a str> {
114    self.prepared_queries.get(id).map(|e| e.as_str())
115  }
116
117  async fn get_connection<'a, 'b>(&'a self) -> Result<Connection<'b>, Error>
118  where
119    'a: 'b,
120  {
121    Ok(Connection::new(Box::new(self.context.db.clone())))
122  }
123}
124
125#[async_trait::async_trait]
126impl ClientConnection for CtxPool {
127  async fn finish(&mut self, _behavior: ErrorBehavior) -> Result<(), Error> {
128    // todo
129    Ok(())
130  }
131
132  async fn start(&mut self, _behavior: ErrorBehavior) -> Result<(), Error> {
133    // todo
134    Ok(())
135  }
136
137  async fn handle_error(&mut self, _e: Error, _behavior: ErrorBehavior) -> Result<(), Error> {
138    // todo
139    Ok(())
140  }
141
142  async fn exec(&mut self, stmt: String, bound_args: Vec<ConvertedType>) -> Result<u64, Error> {
143    self.run_exec(&stmt, bound_args).await
144  }
145
146  async fn query<'a, 'b>(
147    &'a mut self,
148    stmt: &'a str,
149    bound_args: Vec<ConvertedType>,
150  ) -> Result<BoxStream<'b, Result<Value, Error>>, Error>
151  where
152    'a: 'b,
153  {
154    let stream = self.run_query(stmt.as_ref(), bound_args);
155    Ok(stream.boxed())
156  }
157}
158
159impl ConfigValidation for SqlXComponent {
160  type Config = SqlComponentConfig;
161  fn validate(config: &Self::Config, resolver: &Resolver) -> Result<(), ComponentError> {
162    Ok(validate(config, resolver)?)
163  }
164}
165
166fn validate(config: &SqlComponentConfig, _resolver: &Resolver) -> Result<(), Error> {
167  let bad_ops: Vec<_> = config
168    .operations()
169    .iter()
170    .filter(|op| {
171      let outputs = op.outputs();
172      outputs.len() > 1 || outputs.len() == 1 && outputs[0] != Field::new("output", Type::Object)
173    })
174    .map(|op| op.name().to_owned())
175    .collect();
176
177  if !bad_ops.is_empty() {
178    return Err(Error::InvalidOutput(bad_ops));
179  }
180
181  Ok(())
182}
183
184async fn init_client(config: &SqlComponentConfig, addr: &Url) -> Result<CtxPool, Error> {
185  let pool = match addr.scheme() {
186    "file" => CtxPool::SqlLite(
187      sqlite::connect(
188        config,
189        Some(
190          addr
191            .to_file_path()
192            .map_err(|_e| Error::SqliteConnect(format!("could not convert url {} to filepath", addr)))?
193            .to_str()
194            .unwrap(),
195        ),
196      )
197      .await?,
198    ),
199    "postgres" => CtxPool::Postgres(postgres::connect(config, addr).await?),
200    "sqlite" => {
201      if addr.host() != Some(url::Host::Domain("memory")) {
202        return Err(Error::SqliteScheme);
203      }
204      CtxPool::SqlLite(sqlite::connect(config, None).await?)
205    }
206    "mysql" => unimplemented!("MySql is not supported yet"),
207    "mssql" => unreachable!(),
208    s => return Err(Error::InvalidScheme(s.to_owned())),
209  };
210  debug!(%addr, "connected to db");
211  Ok(pool)
212}
213
214async fn init_context(config: &SqlComponentConfig, addr: &Url) -> Result<Context, Error> {
215  let client = init_client(config, addr).await?;
216  let db = client;
217
218  Ok(Context { db })
219}
220
#[cfg(test)]
mod test {
  use anyhow::Result;
  use wick_config::config::components::{
    SqlComponentConfigBuilder,
    SqlOperationDefinition,
    SqlQueryOperationDefinitionBuilder,
  };
  use wick_config::config::{ResourceDefinition, TcpPort};
  use wick_interface_types::{Field, Type};

  use super::*;

  /// Compile-time check that `SqlXComponent` is both `Send` and `Sync`.
  #[test]
  const fn test_component() {
    // The bound must name both traits; with `Sync` alone the helper would not verify `Send`.
    const fn is_send_sync<T: Send + Sync>() {}
    is_send_sync::<SqlXComponent>();
  }

  /// An operation whose single output is not `Field::new("output", Type::Object)` must be
  /// rejected by `validate`.
  #[test_logger::test(test)]
  fn test_validate() -> Result<()> {
    let mut config = SqlComponentConfigBuilder::default()
      .resource("db")
      .tls(false)
      .build()
      .unwrap();
    // The output is declared as `Type::String`, which `validate` does not accept.
    let op = SqlQueryOperationDefinitionBuilder::default()
      .name("test")
      .query("select * from users where user_id = $1;")
      .inputs([Field::new("input", Type::I32)])
      .outputs([Field::new("output", Type::String)])
      .arguments(["input".to_owned()])
      .build()
      .unwrap();

    config.operations_mut().push(SqlOperationDefinition::Query(op));
    let mut app_config = wick_config::config::AppConfiguration::default();
    app_config.add_resource("db", ResourceDefinition::TcpPort(TcpPort::new("0.0.0.0", 11111)));

    let result = validate(&config, &app_config.resolver());
    assert!(result.is_err());
    Ok(())
  }
}
262    Ok(())
263  }
264}