use std::io::Write;
use diesel::{
AsChangeset, QueryResult,
backend::{Backend, DieselReserveSpecialization, SqlDialect, TrustedBackend, sql_dialect},
deserialize::FromSql,
query_builder::{
AstPass, IntoUpdateTarget, LimitOffsetClause, QueryBuilder, QueryFragment,
bind_collector::RawBytesBindCollector,
},
row::{Field, PartialRow, Row, RowIndex, RowSealed},
serialize::ToSql,
sql_types::{
BigInt, Binary, Date, Double, Float, HasSqlType, Integer, SmallInt, Text, Time, Timestamp,
TypeMetadata,
},
};
use futures_util::{
StreamExt,
stream::{self, BoxStream},
};
use super::{AsyncConnection, AsyncExecute, RunQueryDsl};
use crate::{AsyncTransaction, AsyncTransactional, UpdateAndFetchResults, methods::LoadQuery};
/// In-memory stand-in for a database connection, used in tests.
///
/// Instead of talking to a database it hands back results that were queued
/// up beforehand through [`TestConnection::expected_load`] and
/// [`TestConnection::expected_return_count`].
#[derive(Debug, Default)]
pub struct TestConnection {
    /// Rows to hand out for the next load; each inner `Vec` is one row.
    expected_load: Option<Vec<Vec<String>>>,
    /// Row count to report for the next count-returning execution.
    expected_return_count: Option<usize>,
}

impl TestConnection {
    /// Queues the rows that the next load should yield.
    ///
    /// Any rows queued earlier and not yet consumed are discarded.
    pub fn expected_load(&mut self, value: Vec<Vec<String>>) {
        let _previous = self.expected_load.replace(value);
    }

    /// Queues the count that the next count-returning execution should report.
    ///
    /// Any count queued earlier and not yet consumed is discarded.
    pub fn expected_return_count(&mut self, value: usize) {
        let _previous = self.expected_return_count.replace(value);
    }
}
/// Marker type for the diesel backend used by the test connection.
///
/// It carries no data; all behaviour comes from the trait impls in this file.
pub struct TestBackend;
/// One result row of the test backend.
pub struct TestRow {
// Column values in positional order, each stored as a string.
data: Vec<String>,
}
/// A single column value borrowed from a [`TestRow`], held as a string slice.
pub struct TestField<'f>(&'f str);
/// The test backend keeps no type metadata: both the metadata value and the
/// lookup handle are the unit type.
impl TypeMetadata for TestBackend {
type TypeMetadata = ();
type MetadataLookup = ();
}
// Marker impl with no items of its own.
// NOTE(review): presumably this opts the backend into diesel's generic
// `QueryFragment` impls — confirm against the diesel backend docs.
impl DieselReserveSpecialization for TestBackend {}
impl HasSqlType<Timestamp> for TestBackend {
fn metadata(_lookup: &mut Self::MetadataLookup) -> Self::TypeMetadata {}
}
impl HasSqlType<Time> for TestBackend {
fn metadata(_lookup: &mut Self::MetadataLookup) -> Self::TypeMetadata {}
}
impl HasSqlType<Date> for TestBackend {
fn metadata(_lookup: &mut Self::MetadataLookup) -> Self::TypeMetadata {}
}
impl HasSqlType<Binary> for TestBackend {
fn metadata(_lookup: &mut Self::MetadataLookup) -> Self::TypeMetadata {}
}
impl HasSqlType<Text> for TestBackend {
fn metadata(_lookup: &mut Self::MetadataLookup) -> Self::TypeMetadata {}
}
impl HasSqlType<Double> for TestBackend {
fn metadata(_lookup: &mut Self::MetadataLookup) -> Self::TypeMetadata {}
}
impl HasSqlType<Float> for TestBackend {
fn metadata(_lookup: &mut Self::MetadataLookup) -> Self::TypeMetadata {}
}
impl HasSqlType<BigInt> for TestBackend {
fn metadata(_lookup: &mut Self::MetadataLookup) -> Self::TypeMetadata {}
}
impl HasSqlType<Integer> for TestBackend {
fn metadata(_lookup: &mut Self::MetadataLookup) -> Self::TypeMetadata {}
}
impl HasSqlType<SmallInt> for TestBackend {
fn metadata(_lookup: &mut Self::MetadataLookup) -> Self::TypeMetadata {}
}
// Marker impl with no items of its own.
// NOTE(review): see diesel's `TrustedBackend` docs for the guarantees this asserts.
impl TrustedBackend for TestBackend {}
/// Dialect configuration for the test backend.
///
/// Each associated type selects one of diesel's `sql_dialect` marker types.
impl SqlDialect for TestBackend {
// Insert / returning / conflict handling.
type ReturningClause = sql_dialect::returning_clause::PgLikeReturningClause;
type DefaultValueClauseForInsert = sql_dialect::default_value_clause::AnsiDefaultValueClause;
type OnConflictClause = sql_dialect::on_conflict_clause::DoesNotSupportOnConflictClause;
type ConcatClause = sql_dialect::concat_clause::ConcatWithPipesClause;
type InsertWithDefaultKeyword = sql_dialect::default_keyword_for_insert::IsoSqlDefaultKeyword;
type BatchInsertSupport = sql_dialect::batch_insert_support::PostgresLikeBatchInsertSupport;
// Select statement syntax.
type EmptyFromClauseSyntax = sql_dialect::from_clause_syntax::AnsiSqlFromClauseSyntax;
type SelectStatementSyntax = sql_dialect::select_statement_syntax::AnsiSqlSelectStatement;
type ExistsSyntax = sql_dialect::exists_syntax::AnsiSqlExistsSyntax;
type ArrayComparison = sql_dialect::array_comparison::AnsiSqlArrayComparison;
type AliasSyntax = sql_dialect::alias_syntax::AsAliasSyntax;
// Window and aggregate function support.
type WindowFrameClauseGroupSupport =
sql_dialect::window_frame_clause_group_support::NoGroupWindowFrameUnit;
type WindowFrameExclusionSupport =
sql_dialect::window_frame_exclusion_support::NoFrameFrameExclusionSupport;
type AggregateFunctionExpressions =
sql_dialect::aggregate_function_expressions::NoAggregateFunctionExpressions;
type BuiltInWindowFunctionRequireOrder =
sql_dialect::built_in_window_function_require_order::NoOrderRequired;
}
/// Limit/offset clauses contribute nothing to the query on the test backend.
impl<A, B> QueryFragment<TestBackend> for LimitOffsetClause<A, B> {
    /// Ignores the AST pass and always reports success.
    fn walk_ast<'b>(&'b self, _out: AstPass<'_, 'b, TestBackend>) -> QueryResult<()> {
        Ok(())
    }
}
/// Query builder for [`TestBackend`]; it holds no state.
pub struct TestQueryBuilder;
/// A query builder that throws away everything pushed into it.
impl QueryBuilder<TestBackend> for TestQueryBuilder {
    /// Discards the SQL fragment.
    fn push_sql(&mut self, _fragment: &str) {}

    /// Does nothing.
    fn push_bind_param(&mut self) {}

    /// Does nothing.
    fn push_bind_param_value_only(&mut self) {}

    /// Discards the identifier and always succeeds.
    fn push_identifier(&mut self, _name: &str) -> QueryResult<()> {
        Ok(())
    }

    /// Always produces an empty string, since nothing was recorded.
    fn finish(self) -> String {
        String::default()
    }
}
/// Wires the test backend's associated types together.
impl Backend for TestBackend {
type QueryBuilder = TestQueryBuilder;
type BindCollector<'a> = RawBytesBindCollector<Self>;
// Raw column values are owned strings; the `FromSql` impls in this file
// receive them in that form.
type RawValue<'a> = String;
}
// Marker impl with no items; provided alongside the `Row` impl for `TestRow`.
impl RowSealed for TestRow {}
/// Positional column lookup for [`TestRow`].
impl RowIndex<usize> for TestRow {
    /// Returns `Some(idx)` when `idx` addresses an existing column and `None`
    /// when it is past the end of the row.
    ///
    /// Reporting out-of-range positions as `None` lets callers that go on to
    /// index the row's data get a missing field rather than a panic.
    fn idx(&self, idx: usize) -> Option<usize> {
        (idx < self.data.len()).then_some(idx)
    }
}
/// Lookup by column name is not supported: `TestRow` stores values only, with
/// no column names.
impl RowIndex<&str> for TestRow {
// Panics via `unimplemented!()` whenever it is called.
fn idx(&self, _idx: &str) -> Option<usize> {
unimplemented!()
}
}
/// Field access for a single string-backed column value.
impl<'f> Field<'f, TestBackend> for TestField<'f> {
    /// Not supported: panics via `unimplemented!()` when called.
    fn field_name(&self) -> Option<&str> {
        unimplemented!()
    }

    /// Returns an owned copy of the column's string; never `None`.
    fn value(&self) -> Option<<TestBackend as Backend>::RawValue<'_>> {
        let raw = self.0.to_owned();
        Some(raw)
    }
}
/// Row access for [`TestRow`], whose columns are the strings in `data`.
impl<'a> Row<'a, TestBackend> for TestRow {
    type Field<'f>
        = TestField<'f>
    where
        'a: 'f,
        Self: 'f;
    type InnerPartialRow = Self;

    /// Number of columns stored in this row.
    fn field_count(&self) -> usize {
        self.data.len()
    }

    /// Wraps this row in a [`PartialRow`] restricted to `range`.
    fn partial_row(
        &self,
        range: std::ops::Range<usize>,
    ) -> diesel::row::PartialRow<'_, Self::InnerPartialRow> {
        PartialRow::new(self, range)
    }

    /// Returns the field at `idx`, or `None` when the index cannot be
    /// resolved or points past the last column.
    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
    where
        'a: 'b,
        Self: diesel::row::RowIndex<I>,
    {
        let idx = self.idx(idx)?;
        // Checked access: an out-of-range position yields `None` instead of
        // panicking the way direct slice indexing would.
        self.data
            .get(idx)
            .map(|value| TestField(value.as_str()))
    }
}
impl AsyncExecute for TestConnection {
type Stream<'conn> = BoxStream<'conn, QueryResult<TestRow>>;
type Row<'conn> = TestRow;
type Backend = TestBackend;
async fn batch_execute(&mut self, _query: &str) -> QueryResult<()> {
unimplemented!()
}
async fn execute_returning_count<T>(&mut self, _source: T) -> QueryResult<usize>
where
T: diesel::query_builder::QueryFragment<Self::Backend>
+ diesel::query_builder::QueryId
+ Send,
{
self.expected_return_count
.take()
.ok_or(diesel::result::Error::NotFound)
}
async fn load<T>(&mut self, _source: T) -> QueryResult<Self::Stream<'_>>
where
T: diesel::query_builder::AsQuery + Send,
T::Query: diesel::query_builder::QueryFragment<Self::Backend>
+ diesel::query_builder::QueryId
+ Send,
{
let data = self
.expected_load
.take()
.ok_or(diesel::result::Error::NotFound)?;
Ok(stream::iter(data.into_iter().map(|data| Ok(TestRow { data }))).boxed())
}
}
/// Connection lifecycle for the test double.
impl AsyncConnection for TestConnection {
    /// Ignores the URL and returns a connection with no queued expectations.
    async fn establish(_database_url: &str) -> diesel::ConnectionResult<Self> {
        Ok(Default::default())
    }

    /// The test connection never reports itself as broken.
    fn is_broken(&self) -> bool {
        false
    }
}
/// Text columns are already owned strings on the test backend.
impl FromSql<Text, TestBackend> for String {
    /// Returns the raw value unchanged; never fails.
    fn from_sql(
        raw: <TestBackend as Backend>::RawValue<'_>,
    ) -> diesel::deserialize::Result<Self> {
        Ok(raw)
    }
}
/// Integer columns arrive as decimal text on the test backend.
impl FromSql<Integer, TestBackend> for i32 {
    /// Parses the raw string as an `i32`, passing on the parse error when the
    /// text is not a valid integer.
    fn from_sql(
        raw: <TestBackend as Backend>::RawValue<'_>,
    ) -> diesel::deserialize::Result<Self> {
        raw.parse::<i32>().map_err(Into::into)
    }
}
/// Bind-parameter serialization of `i32` for the test backend.
impl ToSql<Integer, TestBackend> for i32 {
    /// Writes the integer's native-endian bytes to `out` and reports the
    /// value as non-null.
    fn to_sql<'b>(
        &'b self,
        out: &mut diesel::serialize::Output<'b, '_, TestBackend>,
    ) -> diesel::serialize::Result {
        let encoded = self.to_ne_bytes();
        out.write_all(&encoded)?;
        Ok(diesel::serialize::IsNull::No)
    }
}
/// In-memory stand-in for a database transaction, used in tests.
///
/// It hands back results that were queued up beforehand through
/// [`TestTransaction::expected_load`] and
/// [`TestTransaction::expected_return_count`].
#[derive(Debug, Default)]
pub struct TestTransaction {
    /// Rows to hand out for the next load; each inner `Vec` is one row.
    expected_load: Option<Vec<Vec<String>>>,
    /// Row count to report for the next count-returning execution.
    expected_return_count: Option<usize>,
}

impl TestTransaction {
    /// Queues the rows that the next load should yield.
    ///
    /// Any rows queued earlier and not yet consumed are discarded.
    pub fn expected_load(&mut self, value: Vec<Vec<String>>) {
        let _previous = self.expected_load.replace(value);
    }

    /// Queues the count that the next count-returning execution should report.
    ///
    /// Any count queued earlier and not yet consumed is discarded.
    pub fn expected_return_count(&mut self, value: usize) {
        let _previous = self.expected_return_count.replace(value);
    }
}
/// Transaction completion for the test double: both outcomes consume the
/// transaction and succeed without doing any work.
impl AsyncTransaction for TestTransaction {
// Always succeeds.
async fn commit(self) -> QueryResult<()> {
Ok(())
}
// Always succeeds.
async fn rollback(self) -> QueryResult<()> {
Ok(())
}
}
impl AsyncExecute for TestTransaction {
type Stream<'conn> = BoxStream<'conn, QueryResult<TestRow>>;
type Row<'conn> = TestRow;
type Backend = TestBackend;
async fn batch_execute(&mut self, _query: &str) -> QueryResult<()> {
unimplemented!()
}
async fn execute_returning_count<T>(&mut self, _source: T) -> QueryResult<usize>
where
T: diesel::query_builder::QueryFragment<Self::Backend>
+ diesel::query_builder::QueryId
+ Send,
{
self.expected_return_count
.take()
.ok_or(diesel::result::Error::NotFound)
}
async fn load<T>(&mut self, _source: T) -> QueryResult<Self::Stream<'_>>
where
T: diesel::query_builder::AsQuery + Send,
T::Query: diesel::query_builder::QueryFragment<Self::Backend>
+ diesel::query_builder::QueryId
+ Send,
{
let data = self
.expected_load
.take()
.ok_or(diesel::result::Error::NotFound)?;
Ok(stream::iter(data.into_iter().map(|data| Ok(TestRow { data }))).boxed())
}
}
/// Lets a [`TestConnection`] open test transactions.
impl AsyncTransactional for TestConnection {
    type Transaction<'a> = TestTransaction;

    /// Always succeeds with a default transaction.
    ///
    /// The new transaction has no queued expectations; nothing queued on the
    /// connection is carried over.
    async fn begin_transaction(&mut self) -> QueryResult<Self::Transaction<'_>> {
        let transaction = TestTransaction::default();
        Ok(transaction)
    }
}
/// Lets a [`TestTransaction`] open nested test transactions.
impl AsyncTransactional for TestTransaction {
    type Transaction<'a> = TestTransaction;

    /// Always succeeds with a default transaction.
    ///
    /// The nested transaction has no queued expectations; nothing queued on
    /// the parent is carried over.
    async fn begin_transaction(&mut self) -> QueryResult<Self::Transaction<'_>> {
        let nested = TestTransaction::default();
        Ok(nested)
    }
}
/// Update-and-fetch support for [`TestConnection`].
///
/// The bounds mirror what `diesel::update(..).set(..)` followed by
/// `get_result` needs from the changeset and its table.
impl<Changes, Output, Tab, V> UpdateAndFetchResults<Changes, Output> for TestConnection
where
    Output: Send,
    Changes:
        Copy + AsChangeset<Target = Tab> + Send + diesel::associations::Identifiable<Table = Tab>,
    Tab: diesel::Table + diesel::query_dsl::methods::FindDsl<Changes::Id>,
    diesel::dsl::Find<Tab, Changes::Id>: IntoUpdateTarget<Table = Tab, WhereClause = V>,
    diesel::query_builder::UpdateStatement<Tab, V, Changes::Changeset>:
        diesel::query_builder::AsQuery,
    diesel::dsl::Update<Changes, Changes>: LoadQuery<Self, Output>,
    V: Send,
    Changes::Changeset: Send,
    Tab::FromClause: Send,
{
    /// Builds an update statement from `changeset` and runs it through
    /// `get_result` on this connection, returning the fetched value.
    async fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult<Output> {
        // `Changes: Copy`, so the same value serves as both the update target
        // and the set of changes.
        let statement = diesel::update(changeset).set(changeset);
        statement.get_result(self).await
    }
}