1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
// Copyright 2021 Datafuse Labs.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::binary::Encoder;
use crate::connection::Connection;
use crate::error_codes::WRONG_PASSWORD;
use crate::errors::Error;
use crate::errors::Result;
use crate::errors::ServerError;
use crate::protocols::HelloResponse;
use crate::protocols::Packet;
use crate::protocols::Stage;
use crate::protocols::SERVER_PONG;
use crate::CHContext;
/// A single decoded client packet, wrapped so it can be applied to a connection.
pub struct Cmd {
    packet: Packet,
}

impl Cmd {
    /// Wraps a decoded client packet into a command.
    pub fn create(packet: Packet) -> Self {
        Self { packet }
    }

    /// Handles the wrapped packet against `connection`, updating the per-connection `ctx`.
    ///
    /// * `Ping` — replies with `SERVER_PONG`.
    /// * `Cancel` — currently ignored.
    /// * `Hello` — authenticates the user, stores the negotiated protocol revision (the smaller
    ///   of the server's and the client's) in `ctx.client_revision`, and replies with a
    ///   `HelloResponse` encoded for that revision.
    /// * `Query` — stores the query text and compression setting in `ctx.state`, then executes
    ///   the query through the session. If execution installed an insert sink
    ///   (`ctx.state.out`), the stage moves to `Stage::InsertPrepare` to await data blocks;
    ///   otherwise an end-of-stream packet is written.
    /// * `Data` — non-empty blocks are forwarded to the insert sink, if one is installed.
    ///   Empty blocks drive the insert stage: `InsertPrepare` → `InsertStarted`, and a second
    ///   empty block in `InsertStarted` finishes the insert and returns to `Stage::Default`.
    ///
    /// Any reply bytes accumulated in the encoder are written to the connection at the end.
    ///
    /// # Errors
    ///
    /// Returns an error, after reporting it to the client, when authentication fails or when a
    /// data block cannot be handed to the insert sink. Errors from query execution, response
    /// encoding and connection writes are propagated unchanged.
    pub async fn apply(self, connection: &mut Connection, ctx: &mut CHContext) -> Result<()> {
        let mut encoder = Encoder::new();
        match self.packet {
            Packet::Ping => {
                encoder.uvarint(SERVER_PONG);
            }
            // TODO: support query cancellation; the packet is currently ignored.
            Packet::Cancel => {}
            Packet::Hello(hello) => {
                if !connection
                    .session
                    .authenticate(&hello.user, &hello.password, &connection.client_addr)
                    .await
                {
                    let err = Error::Server(ServerError {
                        code: WRONG_PASSWORD,
                        name: "AuthenticateException".to_owned(),
                        message: "Unknown user or wrong password".to_owned(),
                        stack_trace: "".to_owned(),
                    });
                    connection.write_error(&err).await?;
                    return Err(err);
                }
                let metadata = connection.session.metadata();
                // Read the server revision once so the advertised value and the negotiated
                // value cannot diverge.
                let server_revision = metadata.tcp_protocol_version();
                let (dbms_version_major, dbms_version_minor, dbms_version_patch) =
                    metadata.version();
                let response = HelloResponse {
                    dbms_name: metadata.name().to_string(),
                    dbms_version_major,
                    dbms_version_minor,
                    dbms_tcp_protocol_version: server_revision,
                    timezone: metadata.timezone().to_string(),
                    server_display_name: metadata.display_name().to_string(),
                    dbms_version_patch,
                };
                // Use the older of the two revisions so both sides speak a common protocol.
                ctx.client_revision = server_revision.min(hello.client_revision);
                ctx.hello = Some(hello);
                response.encode(&mut encoder, ctx.client_revision)?;
            }
            Packet::Query(query) => {
                // `query` is owned here, so its text can be moved rather than cloned.
                ctx.state.query = query.query;
                ctx.state.compression = query.compression;
                // The session is cloned because `execute_query` also needs `connection`
                // mutably, and the session is a field of the connection.
                let session = connection.session.clone();
                session.execute_query(ctx, connection).await?;
                if ctx.state.out.is_some() {
                    // An insert sink was installed: wait for the client's data blocks.
                    ctx.state.stage = Stage::InsertPrepare;
                } else {
                    connection.write_end_of_stream().await?;
                }
            }
            Packet::Data(block) => {
                if block.is_empty() {
                    match ctx.state.stage {
                        Stage::InsertPrepare => {
                            ctx.state.stage = Stage::InsertStarted;
                        }
                        Stage::InsertStarted => {
                            // Resetting the state resets `out` (the insert sink), so the
                            // consuming block stream ends.
                            ctx.state.reset();
                            // Wait for the stream to signal (`sent_all_data`) that it has
                            // finished before ending the stream towards the client.
                            // NOTE(review): assumes the signal is not lost if it fires before
                            // this await (e.g. a permit-storing `Notify::notify_one`) — confirm.
                            ctx.state.sent_all_data.notified().await;
                            connection.write_end_of_stream().await?;
                            ctx.state.stage = Stage::Default;
                        }
                        // Empty blocks received outside an insert are ignored.
                        _ => {}
                    }
                } else if let Some(out) = &ctx.state.out {
                    if out.send(block).await.is_err() {
                        // The sink no longer accepts blocks (presumably its receiving side was
                        // dropped). Report the failure instead of panicking the connection task.
                        let err = Error::Server(ServerError {
                            // ClickHouse `UNKNOWN_EXCEPTION`.
                            code: 1002,
                            name: "InsertException".to_owned(),
                            message: "Insert stream is closed; data block was not delivered"
                                .to_owned(),
                            stack_trace: "".to_owned(),
                        });
                        connection.write_error(&err).await?;
                        return Err(err);
                    }
                }
            }
        }
        let bytes = encoder.get_buffer();
        if !bytes.is_empty() {
            connection.write_bytes(bytes).await?;
        }
        Ok(())
    }
}