//! encoderfile 0.6.2
//!
//! Distribute and run transformer encoders with a single file.
//! See the project documentation for usage details.
use crate::{
    common::model_type,
    generated::{embedding, sentence_embedding, sequence_classification, token_classification},
    runtime::AppState,
    services::{Inference, Metadata},
};

mod error;

/// Converts an application state into an [`axum::Router`] that serves gRPC.
///
/// Implementors consume `self` and return a router with the appropriate
/// tonic-generated service(s) registered (see `generate_grpc_server!` below
/// for the blanket per-model implementations).
pub trait GrpcRouter
where
    // `Sized + 'static` because the state is moved into the tonic service;
    // `Clone + Send + Sync` so the service can be shared across connections.
    Self: Sized + Clone + Send + Sync + 'static,
{
    /// Build the gRPC router for this state, consuming `self`.
    fn grpc_router(self) -> axum::Router;
}

/// Thin adapter wrapping an application state so it can implement the
/// tonic-generated service traits (wired up by `generate_grpc_server!`).
///
/// Note: trait bounds are intentionally left off the struct definition and
/// kept on the impl blocks that actually need them (Rust API guideline
/// C-STRUCT-BOUNDS); this is backward-compatible and avoids repeating the
/// bounds at every mention of the type.
pub struct GrpcService<S> {
    // Shared application state providing the `inference` and `metadata`
    // operations delegated to by the generated service impls.
    state: S,
}

impl<S: Inference + Metadata> GrpcService<S> {
    pub fn new(state: S) -> Self {
        Self { state }
    }
}

/// Generates the gRPC plumbing for one model type:
///   1. a [`GrpcRouter`] impl on `AppState<model_type::$model_type>` that
///      registers the tonic-generated server, and
///   2. the tonic service trait impl (`predict` + `get_model_metadata`)
///      delegating to the shared `Inference`/`Metadata` services.
///
/// Parameters:
/// - `$model_type`: variant in `common::model_type` (e.g. `Embedding`)
/// - `$generated_mod`: generated proto module (e.g. `embedding`)
/// - `$server_mod`: generated server submodule (e.g. `embedding_inference_server`)
/// - `$request_path` / `$response_path`: generated request/response messages
/// - `$trait_path`: generated service trait (e.g. `EmbeddingInference`)
/// - `$server_type`: generated server wrapper (e.g. `EmbeddingInferenceServer`)
macro_rules! generate_grpc_server {
    (
        $model_type:ident,
        $generated_mod:ident,
        $server_mod:ident,
        $request_path:ident,
        $response_path:ident,
        $trait_path:ident,
        $server_type:ident
    ) => {
        impl GrpcRouter for AppState<model_type::$model_type> {
            fn grpc_router(self) -> axum::Router {
                // Register the single tonic service for this model type and
                // convert the tonic routes into an axum router.
                tonic::service::Routes::builder()
                    .routes()
                    .add_service($generated_mod::$server_mod::$server_type::new(
                        GrpcService::new(self),
                    ))
                    .into_axum_router()
            }
        }

        #[tonic::async_trait]
        impl $crate::generated::$generated_mod::$server_mod::$trait_path
            for GrpcService<AppState<model_type::$model_type>>
        {
            /// Run inference on the request payload; model/service errors are
            /// mapped to `tonic::Status` via the error module's conversion.
            async fn predict(
                &self,
                request: tonic::Request<$crate::generated::$generated_mod::$request_path>,
            ) -> Result<
                tonic::Response<$crate::generated::$generated_mod::$response_path>,
                tonic::Status,
            > {
                Ok(tonic::Response::new(
                    self.state
                        .inference(request.into_inner())
                        .map_err(|e| e.to_tonic_status())?
                        .into(),
                ))
            }

            /// Return model metadata; the request carries no parameters used
            /// here, so it is ignored.
            async fn get_model_metadata(
                &self,
                _request: tonic::Request<$crate::generated::metadata::GetModelMetadataRequest>,
            ) -> Result<
                tonic::Response<$crate::generated::metadata::GetModelMetadataResponse>,
                tonic::Status,
            > {
                Ok(tonic::Response::new(self.state.metadata().into()))
            }
        }
    };
}

// Instantiate the gRPC router + service impls for every supported model
// type. Argument order: model type, generated module, server submodule,
// request message, response message, service trait, server wrapper.

// Token-level embeddings.
generate_grpc_server!(
    Embedding,
    embedding,
    embedding_inference_server,
    EmbeddingRequest,
    EmbeddingResponse,
    EmbeddingInference,
    EmbeddingInferenceServer
);

// Whole-sequence classification.
generate_grpc_server!(
    SequenceClassification,
    sequence_classification,
    sequence_classification_inference_server,
    SequenceClassificationRequest,
    SequenceClassificationResponse,
    SequenceClassificationInference,
    SequenceClassificationInferenceServer
);

// Per-token classification (e.g. NER-style outputs).
generate_grpc_server!(
    TokenClassification,
    token_classification,
    token_classification_inference_server,
    TokenClassificationRequest,
    TokenClassificationResponse,
    TokenClassificationInference,
    TokenClassificationInferenceServer
);

// Pooled sentence-level embeddings.
generate_grpc_server!(
    SentenceEmbedding,
    sentence_embedding,
    sentence_embedding_inference_server,
    SentenceEmbeddingRequest,
    SentenceEmbeddingResponse,
    SentenceEmbeddingInference,
    SentenceEmbeddingInferenceServer
);