1use std::collections::HashMap;
2
3use flow_component::ComponentError;
4use futures::stream::BoxStream;
5use futures::StreamExt;
6use serde_json::Value;
7use sqlx::{PgPool, SqlitePool};
8use url::Url;
9use wick_config::config::components::{ComponentConfig, OperationConfig, SqlComponentConfig};
10use wick_config::config::ErrorBehavior;
11use wick_config::{ConfigValidation, Resolver};
12use wick_interface_types::{Field, Type};
13
14use crate::common::sql_wrapper::ConvertedType;
15use crate::common::{ClientConnection, Connection, DatabaseProvider};
16use crate::sqlx::{postgres, sqlite};
17use crate::{common, Error};
18
19#[derive(Debug, Clone)]
20enum CtxPool {
21 Postgres(PgPool),
22 SqlLite(SqlitePool),
23}
24
25impl CtxPool {
26 fn run_query<'a, 'b>(&'a self, querystr: &'b str, args: Vec<ConvertedType>) -> BoxStream<'a, Result<Value, Error>>
27 where
28 'b: 'a,
29 {
30 match self {
31 CtxPool::Postgres(c) => {
32 let query = postgres::make_query(querystr, args);
33 let stream = query.fetch(c).map(|res| res.map(postgres::SerMapRow::from)).map(|res| {
34 res
35 .map(|el| serde_json::to_value(el).unwrap_or(Value::Null))
36 .map_err(|e| Error::Fetch(e.to_string()))
37 });
38
39 stream.boxed()
40 }
41 CtxPool::SqlLite(c) => {
42 let query = sqlite::make_query(querystr, args);
43 let stream = query.fetch(c).map(|res| res.map(sqlite::SerMapRow::from)).map(|res| {
44 res
45 .map(|el| serde_json::to_value(el).unwrap_or(Value::Null))
46 .map_err(|e| Error::Fetch(e.to_string()))
47 });
48
49 stream.boxed()
50 }
51 }
52 }
53
54 async fn run_exec<'a, 'q>(&'a self, query: &'q str, args: Vec<ConvertedType>) -> Result<u64, Error>
55 where
56 'q: 'a,
57 {
58 let result = match self {
59 CtxPool::Postgres(c) => {
60 let query = postgres::make_query(query, args);
61 query.execute(c).await.map(|r| r.rows_affected())
62 }
63 CtxPool::SqlLite(c) => {
64 let query = sqlite::make_query(query, args);
65 query.execute(c).await.map(|r| r.rows_affected())
66 }
67 };
68 result.map_err(|e| Error::Exec(e.to_string()))
69 }
70}
71
72#[derive(Clone)]
73pub(crate) struct Context {
74 db: CtxPool,
75}
76
77impl std::fmt::Debug for Context {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("Context").finish()
80 }
81}
82
83impl Context {}
84
85#[derive(Debug, Clone)]
87#[must_use]
88pub(crate) struct SqlXComponent {
89 context: Context,
90 prepared_queries: HashMap<String, String>,
91}
92
93impl SqlXComponent {
94 pub(crate) async fn new(config: SqlComponentConfig, resolver: &Resolver) -> Result<Self, Error> {
96 validate(&config, resolver)?;
97 let url = common::convert_url_resource(resolver, config.resource())?;
98 let context = init_context(&config, &url).await?;
99 let mut queries = HashMap::new();
100 trace!(count=%config.operations().len(), "preparing queries");
101 for op in config.operations() {
102 queries.insert(op.name().to_owned(), op.query().to_owned());
103 }
104 Ok(Self {
105 context,
106 prepared_queries: queries,
107 })
108 }
109}
110
111#[async_trait::async_trait]
112impl DatabaseProvider for SqlXComponent {
113 fn get_statement<'a>(&'a self, id: &'a str) -> Option<&'a str> {
114 self.prepared_queries.get(id).map(|e| e.as_str())
115 }
116
117 async fn get_connection<'a, 'b>(&'a self) -> Result<Connection<'b>, Error>
118 where
119 'a: 'b,
120 {
121 Ok(Connection::new(Box::new(self.context.db.clone())))
122 }
123}
124
125#[async_trait::async_trait]
126impl ClientConnection for CtxPool {
127 async fn finish(&mut self, _behavior: ErrorBehavior) -> Result<(), Error> {
128 Ok(())
130 }
131
132 async fn start(&mut self, _behavior: ErrorBehavior) -> Result<(), Error> {
133 Ok(())
135 }
136
137 async fn handle_error(&mut self, _e: Error, _behavior: ErrorBehavior) -> Result<(), Error> {
138 Ok(())
140 }
141
142 async fn exec(&mut self, stmt: String, bound_args: Vec<ConvertedType>) -> Result<u64, Error> {
143 self.run_exec(&stmt, bound_args).await
144 }
145
146 async fn query<'a, 'b>(
147 &'a mut self,
148 stmt: &'a str,
149 bound_args: Vec<ConvertedType>,
150 ) -> Result<BoxStream<'b, Result<Value, Error>>, Error>
151 where
152 'a: 'b,
153 {
154 let stream = self.run_query(stmt.as_ref(), bound_args);
155 Ok(stream.boxed())
156 }
157}
158
159impl ConfigValidation for SqlXComponent {
160 type Config = SqlComponentConfig;
161 fn validate(config: &Self::Config, resolver: &Resolver) -> Result<(), ComponentError> {
162 Ok(validate(config, resolver)?)
163 }
164}
165
166fn validate(config: &SqlComponentConfig, _resolver: &Resolver) -> Result<(), Error> {
167 let bad_ops: Vec<_> = config
168 .operations()
169 .iter()
170 .filter(|op| {
171 let outputs = op.outputs();
172 outputs.len() > 1 || outputs.len() == 1 && outputs[0] != Field::new("output", Type::Object)
173 })
174 .map(|op| op.name().to_owned())
175 .collect();
176
177 if !bad_ops.is_empty() {
178 return Err(Error::InvalidOutput(bad_ops));
179 }
180
181 Ok(())
182}
183
184async fn init_client(config: &SqlComponentConfig, addr: &Url) -> Result<CtxPool, Error> {
185 let pool = match addr.scheme() {
186 "file" => CtxPool::SqlLite(
187 sqlite::connect(
188 config,
189 Some(
190 addr
191 .to_file_path()
192 .map_err(|_e| Error::SqliteConnect(format!("could not convert url {} to filepath", addr)))?
193 .to_str()
194 .unwrap(),
195 ),
196 )
197 .await?,
198 ),
199 "postgres" => CtxPool::Postgres(postgres::connect(config, addr).await?),
200 "sqlite" => {
201 if addr.host() != Some(url::Host::Domain("memory")) {
202 return Err(Error::SqliteScheme);
203 }
204 CtxPool::SqlLite(sqlite::connect(config, None).await?)
205 }
206 "mysql" => unimplemented!("MySql is not supported yet"),
207 "mssql" => unreachable!(),
208 s => return Err(Error::InvalidScheme(s.to_owned())),
209 };
210 debug!(%addr, "connected to db");
211 Ok(pool)
212}
213
214async fn init_context(config: &SqlComponentConfig, addr: &Url) -> Result<Context, Error> {
215 let client = init_client(config, addr).await?;
216 let db = client;
217
218 Ok(Context { db })
219}
220
#[cfg(test)]
mod test {
  use anyhow::Result;
  use wick_config::config::components::{
    SqlComponentConfigBuilder,
    SqlOperationDefinition,
    SqlQueryOperationDefinitionBuilder,
  };
  use wick_config::config::{ResourceDefinition, TcpPort};
  use wick_interface_types::{Field, Type};

  use super::*;

  /// Compile-time check that `SqlXComponent` is both `Send` and `Sync`.
  #[test]
  const fn test_component() {
    // The bound must name both traits; with `Sync` alone a non-`Send` component would pass.
    const fn is_send_sync<T: Send + Sync>() {}
    is_send_sync::<SqlXComponent>();
  }

  /// An operation whose single output is not `output: object` must fail validation.
  #[test_logger::test(test)]
  fn test_validate() -> Result<()> {
    let mut config = SqlComponentConfigBuilder::default()
      .resource("db")
      .tls(false)
      .build()
      .unwrap();
    // The output is declared as `Type::String`, which `validate` rejects.
    let op = SqlQueryOperationDefinitionBuilder::default()
      .name("test")
      .query("select * from users where user_id = $1;")
      .inputs([Field::new("input", Type::I32)])
      .outputs([Field::new("output", Type::String)])
      .arguments(["input".to_owned()])
      .build()
      .unwrap();

    config.operations_mut().push(SqlOperationDefinition::Query(op));
    let mut app_config = wick_config::config::AppConfiguration::default();
    app_config.add_resource("db", ResourceDefinition::TcpPort(TcpPort::new("0.0.0.0", 11111)));

    let result = validate(&config, &app_config.resolver());
    assert!(result.is_err());
    Ok(())
  }
}