mod bfrt;
mod core;
pub mod error;
mod protos;
pub mod register;
pub mod table;
pub mod util;
use crate::bfrt_proto::forwarding_pipeline_config::Profile;
use crate::bfrt_proto::set_forwarding_pipeline_config_request::{Action, DevInitMode};
use crate::bfrt_proto::{
ForwardingPipelineConfig, ReadResponse, SetForwardingPipelineConfigRequest,
StreamMessageRequest, StreamMessageResponse, WriteResponse,
};
use crate::error::RBFRTError;
use crate::error::RBFRTError::{
ConnectionError, GRPCError, GetForwardingPipelineError, P4ProgramError, RequestEmpty,
UnknownReadResult,
};
use crate::protos::bfrt_proto::data_field::Value;
use crate::protos::bfrt_proto::entity::Entity;
use crate::protos::bfrt_proto::stream_message_response::Update;
use crate::protos::bfrt_proto::{ReadRequest, WriteRequest};
use crate::register::Register;
use crate::table::MatchValue;
use crate::util::Digest;
use bfrt::BFRTInfo;
use bfrt_proto::bf_runtime_client::BfRuntimeClient;
use bfrt_proto::GetForwardingPipelineConfigRequest;
use bfrt_proto::TargetDevice;
use log::{debug, info, warn};
use protos::bfrt_proto;
use std::collections::HashMap;
use std::io::Read;
use std::{fs, str};
use table::{Request, RequestType, TableEntry};
use tokio::sync::Mutex;
use tokio_stream::wrappers::ReceiverStream;
use tonic::transport::Channel;
use tonic::{Response, Streaming};
/// Capacity shared by the stream request channel, the stream response channel and
/// the public digest queue created in `SwitchConnectionBuilder::connect`.
const DIGEST_QUEUE_SIZE: usize = 20_000;
/// Raw gRPC response produced by `SwitchConnection::dispatch_request`.
///
/// The variant depends on the `RequestType` of the first request in the batch:
/// `Read` yields a server-streaming read response, every other type yields a single
/// write response.
// `dead_code`: the `response` carried by `WriteResult` is never read by the visible
// callers — write/update/delete/operation paths only check the result for errors.
#[allow(dead_code)]
// NOTE(review): presumably silences the size difference between the two wrapped
// response types — confirm before removing.
#[allow(clippy::large_enum_variant)]
enum DispatchResult {
ReadResult {
response: Response<Streaming<ReadResponse>>,
},
WriteResult {
response: Response<WriteResponse>,
},
}
/// Builder for a [`SwitchConnection`], obtained from `SwitchConnection::builder`.
///
/// Holds the server address plus the optional settings that `connect` applies.
#[derive(Debug, Clone)]
pub struct SwitchConnectionBuilder {
    /// Server address, interpolated into the `http://{ip}:{port}` endpoint.
    ip: String,
    /// Server gRPC port.
    port: u16,
    /// Device id used in requests and in the target device (builder default: 0).
    device_id: u32,
    /// BF Runtime client id (builder default: 1).
    client_id: u32,
    /// Name of the P4 program to bind to; may instead come from the config file.
    p4_name: Option<String>,
    /// Optional path to a pipeline configuration JSON file pushed on connect.
    config: Option<String>,
}
impl SwitchConnectionBuilder {
    /// Sets the BF Runtime client id used for every request.
    #[must_use]
    pub fn client_id(mut self, client_id: u32) -> SwitchConnectionBuilder {
        self.client_id = client_id;
        self
    }

    /// Sets the device id used for requests and for the target device.
    #[must_use]
    pub fn device_id(mut self, device_id: u32) -> SwitchConnectionBuilder {
        self.device_id = device_id;
        self
    }

    /// Sets the name of the P4 program to bind to.
    ///
    /// If a config file is also set, `connect` replaces this value with the program
    /// name(s) read from that file.
    #[must_use]
    pub fn p4_name(mut self, p4_name: &str) -> SwitchConnectionBuilder {
        self.p4_name = Some(p4_name.to_owned());
        self
    }

    /// Sets the path of a pipeline configuration JSON file to push on connect.
    #[must_use]
    pub fn config(mut self, path: &str) -> SwitchConnectionBuilder {
        self.config = Some(path.to_owned());
        self
    }

    /// Connects to the switch and returns a ready-to-use [`SwitchConnection`].
    ///
    /// Steps, in order:
    /// 1. open a gRPC connection to `http://{ip}:{port}` with 16 MiB message limits,
    /// 2. if a config file was set, push that forwarding pipeline to the device,
    /// 3. subscribe to notifications on the stream channel,
    /// 4. bind to the P4 program and load its BF Runtime info,
    /// 5. start the background task that feeds parsed digests into `digest_queue`.
    ///
    /// # Errors
    ///
    /// Returns `ConnectionError` if the gRPC connection cannot be established, and
    /// propagates any error from the pipeline-setup, subscribe, bind and load steps.
    ///
    /// # Panics
    ///
    /// Panics with "P4 name not set." if no program name is known after step 2, i.e.
    /// it was neither set via `p4_name` nor read from a config file.
    pub async fn connect(self) -> Result<SwitchConnection, RBFRTError> {
        // Limit applied to both encoded and decoded gRPC messages (16 MiB).
        const MAX_GRPC_MESSAGE_SIZE: usize = 16 * 1024 * 1024;

        debug!(
            "Start switch connection to: http://{}:{}.",
            self.ip, self.port
        );
        let client = match BfRuntimeClient::connect(format!("http://{}:{}", self.ip, self.port))
            .await
        {
            Ok(client) => client,
            Err(e) => {
                return Err(ConnectionError {
                    ip: self.ip,
                    port: self.port,
                    orig_e: Box::new(e),
                });
            }
        };
        let bf_client = Mutex::new(
            client
                .max_decoding_message_size(MAX_GRPC_MESSAGE_SIZE)
                .max_encoding_message_size(MAX_GRPC_MESSAGE_SIZE),
        );

        // Outgoing stream-channel requests and incoming stream-channel responses.
        let (request_tx, request_rx) =
            tokio::sync::mpsc::channel::<StreamMessageRequest>(DIGEST_QUEUE_SIZE);
        let (response_tx, mut response_rx) =
            tokio::sync::mpsc::channel::<StreamMessageResponse>(DIGEST_QUEUE_SIZE);
        // Parsed digests handed to the user through `SwitchConnection::digest_queue`.
        let (digest_sender, digest_receiver) = crossbeam_channel::bounded(DIGEST_QUEUE_SIZE);

        let mut connection = SwitchConnection {
            ip: self.ip,
            port: self.port,
            device_id: self.device_id,
            client_id: self.client_id,
            bf_client,
            config: self.config,
            bfrt_info: None,
            target: TargetDevice {
                device_id: self.device_id,
                pipe_id: 0xffff,
                direction: 0xff,
                prsr_id: 0xff,
            },
            p4_name: self.p4_name,
            send_channel: request_tx,
            digest_queue: digest_receiver,
        };

        if let Some(config) = connection.config.clone() {
            // May overwrite `p4_name` with the program name(s) from the file.
            connection.set_forwarding_pipeline(&config).await?;
        }
        assert!(connection.p4_name.is_some(), "P4 name not set.");

        connection
            .subscribe(request_rx, response_tx, &mut response_rx)
            .await?;
        connection.bind_forwarding_pipeline().await?;
        connection.bfrt_info = Some(connection.load_pipeline().await?);
        connection.start_notification_thread(response_rx, digest_sender);

        info!(
            "Switch connection to {}:{} successful.",
            connection.ip, connection.port
        );
        Ok(connection)
    }
}
/// An established connection to a BF Runtime gRPC server.
///
/// Created through `SwitchConnection::builder(..).connect()`. Table, register and
/// operation requests go through `bf_client`; digests received on the notification
/// stream are exposed through the public `digest_queue`.
pub struct SwitchConnection {
// Server address, used for the connection URL and log messages.
ip: String,
// Server gRPC port.
port: u16,
// Device id sent with pipeline-config and subscribe requests.
device_id: u32,
// Client id sent with every request.
client_id: u32,
// gRPC client behind an async mutex so `&self` methods can issue calls.
bf_client: Mutex<BfRuntimeClient<Channel>>,
// BF Runtime info of the bound program plus the non-P4 tables; `None` until
// loaded during `connect`.
bfrt_info: Option<BFRTInfo>,
// Target device for read/write requests; `connect` sets pipe_id 0xffff and
// direction/prsr_id 0xff (NOTE(review): presumably the "all" wildcards — confirm).
target: TargetDevice,
// Name of the P4 program requests are issued against.
p4_name: Option<String>,
// Sender side of the outgoing `stream_channel` request stream.
send_channel: tokio::sync::mpsc::Sender<StreamMessageRequest>,
/// Bounded queue of digests received from the switch. When it is full, newly
/// arriving digests are dropped.
pub digest_queue: crossbeam_channel::Receiver<Digest>,
// Optional path of the pipeline configuration JSON pushed during `connect`.
config: Option<String>,
}
impl SwitchConnection {
    /// Returns a [`SwitchConnectionBuilder`] for the BF Runtime server at `ip:port`.
    ///
    /// Defaults: `device_id` 0, `client_id` 1, no P4 program name, no config file.
    #[must_use]
    pub fn builder(ip: &str, port: u16) -> SwitchConnectionBuilder {
        SwitchConnectionBuilder {
            ip: ip.to_owned(),
            port,
            device_id: 0,
            client_id: 1,
            p4_name: None,
            config: None,
        }
    }

    /// Opens the bidirectional `stream_channel` and subscribes to notifications.
    ///
    /// `request_rx` becomes the outgoing request stream. A background task forwards
    /// incoming `Subscribe` and `Digest` messages to `response_tx`; when that channel
    /// is full, a message is dropped with a warning. Other messages are logged and
    /// ignored. The subscribe request is then sent through `self.send_channel`, and
    /// the first message on `response_rx` must be its successful answer.
    ///
    /// # Errors
    ///
    /// Returns `GRPCError` if the stream closes before an answer arrives, if the first
    /// message is not a subscription answer, or if its status is missing or non-zero.
    // Kept from the original code. NOTE(review): presumably some proto items used
    // here are marked deprecated — confirm.
    #[allow(deprecated)]
    async fn subscribe(
        &self,
        request_rx: tokio::sync::mpsc::Receiver<StreamMessageRequest>,
        response_tx: tokio::sync::mpsc::Sender<StreamMessageResponse>,
        response_rx: &mut tokio::sync::mpsc::Receiver<StreamMessageResponse>,
    ) -> Result<(), RBFRTError> {
        let subscribe_req = StreamMessageRequest {
            client_id: self.client_id,
            update: Some(bfrt_proto::stream_message_request::Update::Subscribe(
                bfrt_proto::Subscribe {
                    is_master: true,
                    device_id: self.device_id,
                    notifications: Some(bfrt_proto::subscribe::Notifications {
                        enable_learn_notifications: true,
                        enable_idletimeout_notifications: true,
                        enable_port_status_change_notifications: true,
                        enable_entry_active_notifications: true,
                    }),
                    status: None,
                },
            )),
        };
        let request = tonic::Request::new(ReceiverStream::new(request_rx));
        // Clone the client so the long-lived stream does not hold the `bf_client` lock.
        let mut client = self.bf_client.lock().await.clone();
        tokio::spawn(async move {
            let mut responses = match client.stream_channel(request).await {
                Ok(response) => response.into_inner(),
                Err(e) => {
                    warn!("Failed to open stream_channel: {e}");
                    return;
                }
            };
            info!("Started stream_channel");
            loop {
                match responses.message().await {
                    Ok(Some(msg)) => {
                        // Inspect without cloning; a message without payload is ignored
                        // instead of panicking the task.
                        if matches!(
                            msg.update,
                            Some(Update::Subscribe(_)) | Some(Update::Digest(_))
                        ) {
                            if let Err(e) = response_tx.try_send(msg) {
                                warn!("Failed to send notification: {e}");
                            }
                        } else {
                            warn!(
                                "Got a notification that is currently not supported. Will be ignored."
                            );
                        }
                    }
                    Ok(None) => {
                        warn!("Stream was closed by sender.");
                        break;
                    }
                    Err(e) => {
                        warn!("Error receiving notification: {e}");
                        break;
                    }
                }
            }
            warn!("Notification channel closed.");
        });

        if self.send_channel.send(subscribe_req).await.is_err() {
            warn!("Notification endpoint hang.")
        }

        // If the stream task ended (e.g. stream_channel failed to open), `response_tx`
        // is dropped and `recv` yields `None`: report it instead of panicking.
        let msg = response_rx.recv().await.ok_or_else(|| GRPCError {
            message: "Notification stream closed before the subscription was confirmed."
                .to_owned(),
            details: String::new(),
        })?;
        match msg.update {
            Some(Update::Subscribe(subscription)) => match subscription.status {
                Some(status) if status.code == 0 => {
                    info!("Notification subscription successful.");
                    Ok(())
                }
                status => Err(GRPCError {
                    message: "Notification subscription failed.".to_owned(),
                    details: format!("{status:?}"),
                }),
            },
            other => Err(GRPCError {
                message: "Notification subscription expected.".to_owned(),
                details: format!("{other:?}"),
            }),
        }
    }

    /// Fetches the device's forwarding pipeline config and builds the [`BFRTInfo`]
    /// of the bound P4 program, extended with the tables of the non-P4 config.
    ///
    /// # Errors
    ///
    /// Returns `GetForwardingPipelineError` if the gRPC call fails, and
    /// `P4ProgramError` if no program named `self.p4_name` is part of the pipeline.
    ///
    /// # Panics
    ///
    /// Panics if the response lacks the non-P4 config, if either BFRT info payload is
    /// not valid JSON, or if no P4 name is set.
    async fn load_pipeline(&self) -> Result<BFRTInfo, RBFRTError> {
        debug!("Loading pipeline.");
        // The client lock is released at the end of this statement, before parsing.
        let msg = self
            .bf_client
            .lock()
            .await
            .get_forwarding_pipeline_config(GetForwardingPipelineConfigRequest {
                device_id: self.device_id,
                client_id: self.client_id,
            })
            .await
            .map_err(|e| GetForwardingPipelineError {
                device_id: self.device_id,
                client_id: self.client_id,
                orig_e: Box::new(e),
            })?
            .into_inner();

        let non_p4_config = msg
            .non_p4_config
            .expect("Forwarding pipeline response has no non-P4 config.");
        let non_p4: BFRTInfo = serde_json::from_slice(&non_p4_config.bfruntime_info)
            .expect("Non-P4 BFRT info is not valid JSON.");
        let non_p4_tables = non_p4.tables();

        let p4_name = self.p4_name.as_deref().expect("P4 name not set.");
        let program = msg
            .config
            .into_iter()
            .find(|program| program.p4_name == p4_name)
            .ok_or_else(|| P4ProgramError {
                name: p4_name.to_owned(),
            })?;

        let mut bfrt_info: BFRTInfo = serde_json::from_slice(&program.bfruntime_info)
            .expect("BFRT info of the P4 program is not valid JSON.");
        for table in &non_p4_tables {
            bfrt_info.add_table(table.clone());
        }
        Ok(bfrt_info)
    }

    /// Spawns the task that turns digest notifications from `response_rx` into
    /// [`Digest`] values on `digest_queue`.
    ///
    /// Each digest's `digest_id` is resolved to a learn filter of the loaded BFRT
    /// info. Only `Stream` field values are kept. Other value types are logged and
    /// skipped, and fields unknown to the filter are skipped silently. Non-digest
    /// notifications and digests with an unresolvable learn filter are logged and
    /// dropped. A digest that `digest_queue` cannot accept (full or disconnected) is
    /// dropped with a warning.
    ///
    /// # Panics
    ///
    /// Panics if the BFRT info has not been loaded yet.
    fn start_notification_thread(
        &self,
        mut response_rx: tokio::sync::mpsc::Receiver<StreamMessageResponse>,
        digest_queue: crossbeam_channel::Sender<Digest>,
    ) {
        // Check the invariant here, not inside the detached task where a panic is easy
        // to miss.
        let bfrt_info = self
            .bfrt_info
            .clone()
            .expect("BFRT info must be loaded before starting the notification thread.");
        tokio::spawn(async move {
            while let Some(msg) = response_rx.recv().await {
                let digest = match msg.update {
                    Some(Update::Digest(digest)) => digest,
                    _ => {
                        warn!("Received not supported notification. Only Digests are currently supported.");
                        continue;
                    }
                };
                let filter = match bfrt_info.learn_filter_get(digest.digest_id) {
                    Ok(filter) => filter,
                    Err(err) => {
                        warn!("Received an error while retrieving learn filter: {err}");
                        continue;
                    }
                };
                for data in digest.data {
                    let mut digest_fields = HashMap::new();
                    for field in data.fields {
                        if let Ok(field_name) = filter.get_data_field_name_by_id(field.field_id) {
                            match field.value {
                                Some(Value::Stream(bytes)) => {
                                    digest_fields.insert(field_name, bytes);
                                }
                                Some(_) => {
                                    warn!("Not supported digest field type received.");
                                }
                                None => {}
                            }
                        }
                    }
                    let parsed = Digest {
                        name: filter.name.to_owned(),
                        data: digest_fields,
                    };
                    if let Err(e) = digest_queue.try_send(parsed) {
                        warn!("Failed to enqueue digest: {e}");
                    }
                }
            }
            warn!("Notification channel closed.");
        });
    }

    /// Reads the whole file at `file_path` into memory.
    ///
    /// # Panics
    ///
    /// Panics with `Unable to read: <path>` if the file cannot be opened or read.
    fn read_file_to_bytes(&self, file_path: &str) -> Vec<u8> {
        let mut file =
            fs::File::open(file_path).unwrap_or_else(|_| panic!("Unable to read: {file_path}"));
        // Read to EOF instead of pre-sizing from a separate metadata lookup, which could
        // fail or silently truncate if the file changes in between.
        let mut contents = Vec::new();
        file.read_to_end(&mut contents)
            .unwrap_or_else(|_| panic!("Unable to read: {file_path}"));
        contents
    }

    /// Pushes the forwarding pipeline described by the JSON file `config_file`
    /// (action `VerifyAndWarmInitBeginAndEnd`, mode `FastReconfig`).
    ///
    /// Only the first device of the file is used. `self.p4_name` is set to each listed
    /// program in turn, so the last program listed becomes the bound one.
    ///
    /// # Errors
    ///
    /// Returns `GRPCError` if the server rejects the configuration.
    ///
    /// # Panics
    ///
    /// Panics if the config file or any referenced artifact cannot be read, if the file
    /// is not valid JSON, or if it lists no device.
    async fn set_forwarding_pipeline(&mut self, config_file: &str) -> Result<(), RBFRTError> {
        debug!("Set forwarding pipeline.");
        // serde_json parses an in-memory slice much faster than an unbuffered `File`.
        let raw_config = fs::read(config_file)
            .unwrap_or_else(|_| panic!("config file: {config_file} not readable."));
        let config: core::Configuration =
            serde_json::from_slice(&raw_config).expect("config file has invalid json format.");
        let device = config
            .p4_devices
            .first()
            .expect("config file does not list any p4_devices.");

        let mut forwarding_configs: Vec<ForwardingPipelineConfig> = vec![];
        for program in &device.p4_programs {
            self.p4_name = Some(program.program_name.clone());
            let profiles: Vec<Profile> = program
                .p4_pipelines
                .iter()
                .map(|pipeline| Profile {
                    profile_name: pipeline.p4_pipeline_name.to_owned(),
                    context: self.read_file_to_bytes(&pipeline.context),
                    binary: self.read_file_to_bytes(&pipeline.config),
                    pipe_scope: pipeline.pipe_scope.clone(),
                })
                .collect();
            forwarding_configs.push(ForwardingPipelineConfig {
                p4_name: program.program_name.to_owned(),
                bfruntime_info: self.read_file_to_bytes(&program.bfrt_config),
                profiles,
            });
        }

        let request = SetForwardingPipelineConfigRequest {
            device_id: self.device_id,
            client_id: self.client_id,
            action: Action::VerifyAndWarmInitBeginAndEnd.into(),
            dev_init_mode: DevInitMode::FastReconfig.into(),
            base_path: String::new(),
            config: forwarding_configs,
        };
        let response = self
            .bf_client
            .lock()
            .await
            .set_forwarding_pipeline_config(request)
            .await;
        response.map(|_| ()).map_err(|e| GRPCError {
            message: e.to_string(),
            details: format!("{:?}", e.details()),
        })
    }

    /// Binds this client to the P4 program named `self.p4_name` (action `Bind`).
    ///
    /// # Errors
    ///
    /// Returns `GRPCError` if the server rejects the bind.
    ///
    /// # Panics
    ///
    /// Panics if no P4 name is set.
    async fn bind_forwarding_pipeline(&self) -> Result<(), RBFRTError> {
        let p4_name = self.p4_name.as_deref().expect("P4 name not set.");
        debug!("Bind forwarding pipeline: {p4_name}.");
        let request = SetForwardingPipelineConfigRequest {
            device_id: self.device_id,
            client_id: self.client_id,
            action: Action::Bind.into(),
            dev_init_mode: DevInitMode::FastReconfig.into(),
            base_path: String::new(),
            config: vec![ForwardingPipelineConfig {
                p4_name: p4_name.to_owned(),
                bfruntime_info: vec![],
                profiles: vec![],
            }],
        };
        let response = self
            .bf_client
            .lock()
            .await
            .set_forwarding_pipeline_config(request)
            .await;
        match response {
            Ok(_) => {
                info!("Bind to forwarding pipeline successful.");
                Ok(())
            }
            Err(e) => {
                warn!("Bind forwarding pipeline failed.");
                Err(GRPCError {
                    message: e.to_string(),
                    details: format!("{:?}", e.details()),
                })
            }
        }
    }

    /// Returns a copy of the target device used for requests.
    fn get_target_device(&self) -> TargetDevice {
        self.target.clone()
    }

    /// Sends `request` as a table operation.
    ///
    /// # Errors
    ///
    /// Propagates table-lookup, request-building and gRPC errors.
    pub async fn execute_operation(&self, request: Request) -> Result<(), RBFRTError> {
        debug!("Execute operation {request:?}");
        let requests = vec![request.request_type(RequestType::Operation)];
        self.dispatch_request(&requests).await?;
        Ok(())
    }

    /// Returns whether the loaded BFRT info contains a table named `name`.
    ///
    /// # Panics
    ///
    /// Panics if the BFRT info has not been loaded.
    pub fn has_table(&self, name: &str) -> bool {
        self.bfrt_info
            .as_ref()
            .expect("BFRT info not loaded.")
            .table_get(name)
            .is_ok()
    }

    /// Reads the table entries matching a single `request`.
    ///
    /// # Errors
    ///
    /// See [`SwitchConnection::get_tables_entries`].
    pub async fn get_table_entries(&self, request: Request) -> Result<Vec<TableEntry>, RBFRTError> {
        self.get_tables_entries(vec![request]).await
    }

    /// Reads the table entries matching all `requests` in a single gRPC read.
    ///
    /// Every message of the response stream is consumed, so large results that the
    /// server splits across several messages are returned completely.
    ///
    /// # Errors
    ///
    /// Returns `RequestEmpty` for an empty `requests`, `UnknownReadResult` if the
    /// server returns something other than a table entry, and propagates table-lookup,
    /// parsing and gRPC errors.
    pub async fn get_tables_entries(
        &self,
        requests: Vec<Request>,
    ) -> Result<Vec<TableEntry>, RBFRTError> {
        let read_requests: Vec<Request> = requests
            .into_iter()
            .map(|request| request.request_type(RequestType::Read))
            .collect();
        let response = match self.dispatch_request(&read_requests).await? {
            DispatchResult::ReadResult { response } => response,
            DispatchResult::WriteResult { .. } => {
                unreachable!("read requests must produce a read result")
            }
        };
        let bfrt_info = self.bfrt_info.as_ref().expect("BFRT info not loaded.");
        let mut stream = response.into_inner();
        let mut entries = vec![];
        while let Some(message) = stream.message().await? {
            for entity in message.entities {
                let entity = entity.entity.ok_or(UnknownReadResult {})?;
                let table_id = match &entity {
                    Entity::TableEntry(table_entry) => table_entry.table_id,
                    _ => return Err(UnknownReadResult {}),
                };
                let table = bfrt_info.table_get_by_id(table_id)?;
                entries.push(table.parse_read_request(entity, table.name())?);
            }
        }
        let _ = stream.trailers().await;
        Ok(entries)
    }

    /// Writes (inserts) a single table entry.
    ///
    /// # Errors
    ///
    /// Propagates table-lookup, request-building and gRPC errors.
    pub async fn write_table_entry(&self, request: Request) -> Result<(), RBFRTError> {
        debug!("Write table entry {request:?}");
        let requests = vec![request.request_type(RequestType::Write)];
        self.dispatch_request(&requests).await?;
        Ok(())
    }

    /// Writes (inserts) several table entries in one gRPC write.
    ///
    /// # Errors
    ///
    /// Returns `RequestEmpty` for an empty `requests`; otherwise as
    /// [`SwitchConnection::write_table_entry`].
    pub async fn write_table_entries(&self, requests: Vec<Request>) -> Result<(), RBFRTError> {
        debug!("Write table entry {requests:?}");
        let requests: Vec<Request> = requests
            .into_iter()
            .map(|request| request.request_type(RequestType::Write))
            .collect();
        self.dispatch_request(&requests).await?;
        Ok(())
    }

    /// Updates a single table entry.
    ///
    /// # Errors
    ///
    /// Propagates table-lookup, request-building and gRPC errors.
    pub async fn update_table_entry(&self, request: Request) -> Result<(), RBFRTError> {
        debug!("Update table entry {request:?}");
        let requests = vec![request.request_type(RequestType::Update)];
        self.dispatch_request(&requests).await?;
        Ok(())
    }

    /// Updates several table entries in one gRPC write.
    ///
    /// # Errors
    ///
    /// Returns `RequestEmpty` for an empty `requests`; otherwise as
    /// [`SwitchConnection::update_table_entry`].
    pub async fn update_table_entries(&self, requests: Vec<Request>) -> Result<(), RBFRTError> {
        debug!("Update table entry {requests:?}");
        let requests: Vec<Request> = requests
            .into_iter()
            .map(|request| request.request_type(RequestType::Update))
            .collect();
        self.dispatch_request(&requests).await?;
        Ok(())
    }

    /// Deletes the table entries matched by `request`.
    ///
    /// # Errors
    ///
    /// Propagates table-lookup, request-building and gRPC errors.
    pub async fn delete_table_entry(&self, request: Request) -> Result<(), RBFRTError> {
        debug!("Delete table entry {request:?}");
        let requests = vec![request.request_type(RequestType::Delete)];
        self.dispatch_request(&requests).await?;
        Ok(())
    }

    /// Deletes the entries matched by each request in one gRPC write.
    ///
    /// # Errors
    ///
    /// Returns `RequestEmpty` for an empty `request`; otherwise as
    /// [`SwitchConnection::delete_table_entry`].
    pub async fn delete_table_entries(&self, request: Vec<Request>) -> Result<(), RBFRTError> {
        debug!("Delete table entries {request:?}");
        let requests: Vec<Request> = request
            .into_iter()
            .map(|request| request.request_type(RequestType::Delete))
            .collect();
        self.dispatch_request(&requests).await?;
        Ok(())
    }

    /// Sends a delete request without match keys for table `name`, which is used
    /// here to clear the table.
    ///
    /// # Errors
    ///
    /// As [`SwitchConnection::delete_table_entry`].
    pub async fn clear_table(&self, name: &str) -> Result<(), RBFRTError> {
        debug!("Clear table : {name}");
        self.delete_table_entry(Request::new(name)).await
    }

    /// Clears every table in `name` in one gRPC write.
    ///
    /// # Errors
    ///
    /// As [`SwitchConnection::delete_table_entries`].
    pub async fn clear_tables(&self, name: Vec<&str>) -> Result<(), RBFRTError> {
        debug!("Clear tables : {name:?}");
        let requests: Vec<Request> = name.iter().map(|table| Request::new(table)).collect();
        self.delete_table_entries(requests).await
    }

    /// Builds the table read request for a register request.
    ///
    /// It matches on `$REGISTER_INDEX` when an index is set, otherwise it has no key.
    fn register_read_request(request: &register::Request) -> Request {
        let table_request = Request::new(request.get_name()).request_type(RequestType::Read);
        match request.get_index().map(|index| MatchValue::exact(index)) {
            Some(key) => table_request.match_key("$REGISTER_INDEX", key),
            None => table_request,
        }
    }

    /// Builds the table write request for a register request.
    ///
    /// The request matches on `$REGISTER_INDEX` and carries the register data as
    /// action data.
    ///
    /// # Errors
    ///
    /// Returns `MissingRegisterIndex` if the request has no index.
    fn register_write_request(request: &register::Request) -> Result<Request, RBFRTError> {
        let index = request
            .get_index()
            .ok_or(RBFRTError::MissingRegisterIndex)?;
        let mut table_request = Request::new(request.get_name())
            .match_key("$REGISTER_INDEX", MatchValue::exact(index));
        for (name, value) in request.get_data() {
            table_request = table_request.action_data(name, value.clone());
        }
        Ok(table_request)
    }

    /// Reads one register: the entry at its index, or all entries when no index is set.
    ///
    /// # Errors
    ///
    /// As [`SwitchConnection::get_table_entries`].
    pub async fn get_register_entry(
        &self,
        request: register::Request,
    ) -> Result<Register, RBFRTError> {
        debug!("Read register {request:?}");
        let entries = self
            .get_table_entries(Self::register_read_request(&request))
            .await?;
        Ok(Register::parse_register_entries(entries, request.get_name()))
    }

    /// Reads several register requests in one gRPC read.
    ///
    /// Each request reads its index, or every entry when no index is set, as in
    /// [`SwitchConnection::get_register_entry`]. The result is named after the first
    /// request's register.
    ///
    /// # Errors
    ///
    /// Returns `RequestEmpty` for an empty `requests`; otherwise as
    /// [`SwitchConnection::get_tables_entries`].
    pub async fn get_register_entries(
        &self,
        requests: Vec<register::Request>,
    ) -> Result<Register, RBFRTError> {
        debug!("Read register {requests:?}");
        let name = requests.first().ok_or(RequestEmpty {})?.get_name();
        let table_requests: Vec<Request> =
            requests.iter().map(Self::register_read_request).collect();
        let entries = self.get_tables_entries(table_requests).await?;
        Ok(Register::parse_register_entries(entries, name))
    }

    /// Writes the data of one register entry at its index.
    ///
    /// # Errors
    ///
    /// Returns `MissingRegisterIndex` if no index is set; otherwise as
    /// [`SwitchConnection::write_table_entry`].
    pub async fn write_register_entry(&self, request: register::Request) -> Result<(), RBFRTError> {
        debug!("Write register {request:?}");
        let table_request = Self::register_write_request(&request)?;
        self.write_table_entry(table_request).await
    }

    /// Writes several register entries in one gRPC write.
    ///
    /// # Errors
    ///
    /// Returns `MissingRegisterIndex` if any request has no index; in that case
    /// nothing is sent. Otherwise as [`SwitchConnection::write_table_entries`].
    pub async fn write_register_entries(
        &self,
        requests: Vec<register::Request>,
    ) -> Result<(), RBFRTError> {
        debug!("Write register {requests:?}");
        let write_requests = requests
            .iter()
            .map(Self::register_write_request)
            .collect::<Result<Vec<Request>, RBFRTError>>()?;
        self.write_table_entries(write_requests).await
    }

    /// Builds and sends one gRPC read or write covering all `requests`.
    ///
    /// The request type of the first element decides the kind of call for the whole
    /// batch. `Read` produces a `ReadRequest`; `Write`/`Update`, `Operation` and
    /// `Delete` produce a `WriteRequest` built with the matching table builder.
    ///
    /// # Errors
    ///
    /// Returns `RequestEmpty` for an empty batch, and propagates table-lookup,
    /// request-building and gRPC errors.
    async fn dispatch_request(&self, requests: &[Request]) -> Result<DispatchResult, RBFRTError> {
        let bfrt_info = self.bfrt_info.as_ref().expect("BFRT info not loaded.");
        let first = requests.first().ok_or(RequestEmpty {})?;

        let updates = match first.get_type() {
            RequestType::Read => {
                let mut entities = Vec::with_capacity(requests.len());
                for req in requests {
                    let table = bfrt_info.table_get(req.get_table_name())?;
                    entities.push(table.build_read_request(req, &self.target)?);
                }
                let read_request = ReadRequest {
                    target: Some(self.get_target_device()),
                    client_id: self.client_id,
                    entities,
                    p4_name: self.p4_name.as_ref().expect("P4 name not set.").to_owned(),
                };
                let response = self.bf_client.lock().await.read(read_request).await?;
                return Ok(DispatchResult::ReadResult { response });
            }
            RequestType::Write | RequestType::Update => {
                let mut updates = Vec::with_capacity(requests.len());
                for req in requests {
                    let table = bfrt_info.table_get(req.get_table_name())?;
                    updates.push(table.build_write_request(req, &self.target)?);
                }
                updates
            }
            RequestType::Operation => {
                let mut updates = Vec::with_capacity(requests.len());
                for req in requests {
                    let table = bfrt_info.table_get(req.get_table_name())?;
                    updates.push(table.build_operation_request(req)?);
                }
                updates
            }
            RequestType::Delete => {
                let mut updates = Vec::with_capacity(requests.len());
                for req in requests {
                    let table = bfrt_info.table_get(req.get_table_name())?;
                    updates.push(table.build_delete_request(req, &self.target)?);
                }
                updates
            }
        };

        // All non-read request types share the same write envelope.
        let write_request = WriteRequest {
            target: Some(self.get_target_device()),
            client_id: self.client_id,
            updates,
            p4_name: self.p4_name.as_ref().expect("P4 name not set.").to_owned(),
            atomicity: 0,
        };
        let response = self.bf_client.lock().await.write(write_request).await?;
        Ok(DispatchResult::WriteResult { response })
    }
}