megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file imperative/src/impl/async_releaser.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include <chrono>
#include <thread>

#include "megbrain/comp_node.h"
#include "megbrain/imperative/blob_manager.h"
#include "megbrain/imperative/resource_manager.h"
#include "megbrain/system.h"

#include "./event_pool.h"

namespace mgb {
namespace imperative {

/**
 * \brief Defers dropping a device blob and/or host raw storage until an event
 * recorded on the owning comp node reports finished.
 *
 * Every add() allocates an event from EventPool::without_timer(), records it
 * on the given comp node and enqueues it together with the resources to be
 * released. A background worker thread (named "releaser") polls each event
 * and resets the held handles once the event has finished.
 */
class AsyncReleaser : public CompNodeDepedentObject {
    //! One pending release; the handles are kept alive until `event` finishes.
    struct WaiterParam {
        CompNode cn;             //!< comp node the event was allocated/recorded on
        CompNode::Event* event;  //!< from EventPool::without_timer(); freed when done
        BlobPtr blob;            //!< blob to release (may be empty)
        HostTensorStorage::RawStorage storage;  //!< host storage to release (may be empty)
    };

    //! Single-consumer async queue whose worker polls the event of each task.
    class Waiter final : public AsyncQueueSC<WaiterParam, Waiter> {
        // NOTE(review): stored but never read within this class.
        AsyncReleaser* m_par_releaser;

    public:
        // The queue base gets 0 (max_spin, per the original note) so the worker
        // does not busy-wait for new tasks, saving CPU cycles.
        explicit Waiter(AsyncReleaser* releaser)
                : AsyncQueueSC<WaiterParam, Waiter>(0), m_par_releaser(releaser) {}

        //! Release the task's handles if its event has finished; otherwise back
        //! off briefly and re-enqueue the task so it is checked again later.
        void process_one_task(WaiterParam& param) {
            if (param.event->finished()) {
                // Drop the held handles first, then recycle the event.
                param.blob.reset();
                param.storage.reset();
                EventPool::without_timer().free(param.event);
                return;
            }

            // Event still pending: sleep a tiny amount so polling does not spin
            // hot, then hand the task back to the queue.
            std::this_thread::sleep_for(std::chrono::microseconds(1));
            add_task(std::move(param));
        }

        void on_async_queue_worker_thread_start() override {
            sys::set_thread_name("releaser");
        }
    };
    Waiter m_waiter{this};

protected:
    //! Drain all queued releases before comp nodes are finalized; nothing
    //! needs to be kept alive afterwards, so an empty pointer is returned.
    std::shared_ptr<void> on_comp_node_finalize() override {
        m_waiter.wait_task_queue_empty();
        return {};
    }

public:
    //! Process-wide instance, created on first use through
    //! ResourceManager::create_global (NOTE(review): lifetime presumably
    //! managed by ResourceManager — confirm).
    static AsyncReleaser* inst() {
        static auto* releaser = ResourceManager::create_global<AsyncReleaser>();
        return releaser;
    }

    //! Wait for the task queue to drain so no pending release is lost.
    ~AsyncReleaser() { m_waiter.wait_task_queue_empty(); }

    //! Release `blob` once an event recorded now on `cn` reports finished.
    void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); }

    //! Release the raw storage of host tensor `hv` once an event recorded now
    //! on its comp node reports finished.
    void add(const HostTensorND& hv) {
        add(hv.comp_node(), {}, hv.storage().raw_storage());
    }

    /**
     * \brief Record an event on `cn` and keep `blob` / `storage` alive until
     * that event has finished.
     * \param cn comp node to record the event on
     * \param blob blob handle to hold (may be empty)
     * \param storage host raw storage to hold (may be empty)
     */
    void add(CompNode cn, BlobPtr blob, HostTensorStorage::RawStorage storage = {}) {
        auto event = EventPool::without_timer().alloc(cn);
        event->record();
        m_waiter.add_task({cn, event, std::move(blob), std::move(storage)});
    }
};
}  // namespace imperative
}  // namespace mgb